!pip install flax transformers ftfy
!pip install diffusers==0.9.0
9 Stable Diffusion in JAX / Flax!
Grateful to share this notebook from Hugging Face. Related blog post here.
🤗 Hugging Face Diffusers has supported Flax since version 0.5.1! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform.
This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to this Colab notebook.
First, make sure you are using a TPU backend. If you are running this notebook in Colab, select Runtime in the menu above, then select the option “Change runtime type” and then select TPU under the Hardware accelerator setting.
Note that JAX is not exclusive to TPUs, but it shines on that hardware because each TPU server has 8 TPU accelerators working in parallel.
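To make the parallelism concrete, here is a minimal plain-Python sketch of how a batch could be divided so that each of the 8 accelerators receives an equal slice. The helper `shard_batch` is hypothetical and only mimics the idea; in JAX the real work is done on arrays by utilities such as `flax.training.common_utils.shard` together with `jax.pmap`.

```python
# Illustrative only: a plain-Python stand-in for splitting a batch across devices.
# `shard_batch` is a hypothetical helper, not part of JAX or Flax.
def shard_batch(batch, num_devices):
    """Split `batch` into `num_devices` equal, contiguous shards."""
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    per_device = len(batch) // num_devices
    return [batch[i * per_device:(i + 1) * per_device] for i in range(num_devices)]


prompts = [f"prompt {i}" for i in range(16)]
shards = shard_batch(prompts, num_devices=8)

# Each of the 8 devices would receive 2 prompts.
print(len(shards), len(shards[0]))
```

The batch size has to be a multiple of the device count, which is why the sketch raises an error otherwise.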
9.1 Setup
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
Found 8 JAX devices of type Cloud TPU.
Then we import all the dependencies.
import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
9.2 Model Loading
Before using the model, you need to accept the model license in order to download and use the weights.
The license is designed to mitigate the potential harmful effects of such a powerful machine learning system. We request users to read the license entirely and carefully. Here we offer a summary:
- You can’t use the model to deliberately produce nor share illegal or harmful outputs or content,
- We claim no rights on the outputs you generate, you are free to use them and are accountable for their use which should not go against the provisions set in the license, and
- You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users.
Flax weights are available in Hugging Face Hub as part of the Stable Diffusion repo. To use them, you need to be a registered user in Hugging Face Hub and use an access token for the code to work. You have two options to provide your access token:
- Use the huggingface-cli login command-line tool in your terminal and paste your token when prompted. It will be saved in a file on your computer.
- Or use notebook_login() in a notebook, which does the same thing.
The following cell will present a login interface unless you’ve already authenticated on this computer. You’ll need to paste your access token.
if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()
Login successful
Your token has been saved to /root/.huggingface/token
TPU devices support bfloat16, an efficient half-float type. We’ll use it for our tests, but you can also use float32 to use full precision instead.
dtype = jnp.bfloat16
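As a rough illustration of what the half-float type trades away, the sketch below emulates bfloat16 in plain Python by keeping only the upper 16 bits of a float32 value (1 sign bit, 8 exponent bits, 7 mantissa bits). The helper names are hypothetical, and simple truncation is an assumption made for brevity; real hardware typically rounds instead.

```python
import struct


def to_bfloat16_bits(x):
    """Return the 16-bit pattern obtained by truncating float32 `x` (no rounding)."""
    (bits32,) = struct.unpack(">I", struct.pack(">f", x))
    return bits32 >> 16


def from_bfloat16_bits(bits16):
    """Expand a 16-bit bfloat16 pattern back into a Python float."""
    (value,) = struct.unpack(">f", struct.pack(">I", bits16 << 16))
    return value


# Round-trip a value through the emulated bfloat16 representation.
approx = from_bfloat16_bits(to_bfloat16_bits(3.14159))
print(approx)  # 3.140625: only about 2-3 significant decimal digits survive
```

The exponent range matches float32, so very large and very small values are still representable; only the precision of the mantissa is reduced.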
Flax is a functional framework, so models are stateless and parameters are stored outside them. Loading the pre-trained Flax pipeline will return both the pipeline itself and the model weights (or parameters). We are using a bf16 version of the weights, which leads to type warnings that you can safely ignore.
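The stateless pattern described above can be sketched in plain Python. This is not the Flax API: `init_params` and `apply` are hypothetical names chosen only to show that the function holds no weights and that the parameters travel alongside it as ordinary data.

```python
# Illustrative only: a toy "model" written in the functional style.
def init_params():
    """Create the parameters; they live outside the model function."""
    return {"weight": 2.0, "bias": 1.0}


def apply(params, x):
    """Pure function: the output depends only on the arguments passed in."""
    return params["weight"] * x + params["bias"]


params = init_params()
y = apply(params, 3.0)

# Changing the parameters means building a new dict; the original is untouched.
new_params = {**params, "bias": 0.0}
y_new = apply(new_params, 3.0)
print(y, y_new)
```

This is why the loading call that follows returns two values, the pipeline and its parameters, and why the parameters are later passed explicitly when generating images.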
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)
Some of the weights of FlaxStableDiffusionSafetyChecker were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/safety_checker:
[('concept_embeds',), ('concept_embeds_weights',), ('special_care_embeds',), ('special_care_embeds_weights',), ('vision_model', 'vision_model', 'embeddings', 'class_embedding'), ('vision_model', 'vision_model', 'embeddings', 'patch_embedding', 'kernel'), ('vision_model', 'vision_model', 'embeddings', 'position_embedding', 'embedding'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), 
('vision_model', 'vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 
'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 
'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 
'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'bias'), 
('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 
'layers', '16', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'kernel'), 
('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'kernel'), 
('vision_model', 'vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 
'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'bias'), ('vision_model', 
'vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 
'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 
'vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 
'layers', '4', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', 
'5', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), 
('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 
'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', 
'9', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'post_layernorm', 'bias'), ('vision_model', 'vision_model', 'post_layernorm', 'scale'), ('vision_model', 'vision_model', 'pre_layrnorm', 'bias'), ('vision_model', 'vision_model', 'pre_layrnorm', 'scale'), ('visual_projection', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
Some of the weights of FlaxCLIPTextModel were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/text_encoder:
[('text_model', 'embeddings', 'position_embedding', 'embedding'), ('text_model', 'embeddings', 'token_embedding', 'embedding'), ('text_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), 
('text_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 
'layers', '11', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '3', 
'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), 
('text_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 
'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 
'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'final_layer_norm', 'bias'), ('text_model', 'final_layer_norm', 'scale')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
Some of the weights of FlaxAutoencoderKL were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/vae:
[('decoder', 'conv_in', 'bias'), ('decoder', 'conv_in', 'kernel'), ('decoder', 'conv_norm_out', 'bias'), ('decoder', 'conv_norm_out', 'scale'), ('decoder', 'conv_out', 'bias'), ('decoder', 'conv_out', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'group_norm', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'group_norm', 'scale'), ('decoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'key', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'conv1', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'conv2', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'norm1', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'norm1', 'scale'), ('decoder', 'mid_block', 'resnets_0', 'norm2', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'norm2', 'scale'), ('decoder', 'mid_block', 'resnets_1', 'conv1', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'mid_block', 'resnets_1', 'conv2', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'mid_block', 'resnets_1', 'norm1', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'norm1', 'scale'), ('decoder', 'mid_block', 'resnets_1', 'norm2', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 
'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_0', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('decoder', 
'up_blocks_1', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_1', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv1', 
'kernel'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_2', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv_shortcut', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv_shortcut', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 
'resnets_2', 'norm2', 'scale'), ('encoder', 'conv_in', 'bias'), ('encoder', 'conv_in', 'kernel'), ('encoder', 'conv_norm_out', 'bias'), ('encoder', 'conv_norm_out', 'scale'), ('encoder', 'conv_out', 'bias'), ('encoder', 'conv_out', 'kernel'), ('encoder', 'down_blocks_0', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_0', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_1', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_1', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv_shortcut', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv_shortcut', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm1', 'scale'), ('encoder', 
'down_blocks_1', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_2', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_2', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_0', 
'conv2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm2', 'scale'), ('encoder', 'mid_block', 'attentions_0', 'group_norm', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'group_norm', 'scale'), ('encoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'key', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'conv1', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'conv2', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'norm1', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'norm1', 'scale'), ('encoder', 'mid_block', 'resnets_0', 'norm2', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'norm2', 'scale'), ('encoder', 'mid_block', 'resnets_1', 'conv1', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'conv1', 'kernel'), 
('encoder', 'mid_block', 'resnets_1', 'conv2', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'mid_block', 'resnets_1', 'norm1', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'norm1', 'scale'), ('encoder', 'mid_block', 'resnets_1', 'norm2', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'norm2', 'scale'), ('post_quant_conv', 'bias'), ('post_quant_conv', 'kernel'), ('quant_conv', 'bias'), ('quant_conv', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~ModelMixin.to_fp32`] for further information on how to do this.
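The warning above points to `to_fp32` as the way to upcast the weights. As a minimal sketch of what such an upcast amounts to, assuming the parameters are an ordinary nested dictionary (pytree) of JAX arrays, the snippet below casts every bfloat16 leaf to float32 with `jax.tree_util.tree_map`. The helper name `upcast_to_fp32` and the `toy_params` dictionary are hypothetical stand-ins for illustration; they are not part of the notebook or of the library, and the real checkpoint is far larger.

```python
import jax
import jax.numpy as jnp


def upcast_to_fp32(params):
    """Return a copy of the pytree with every bfloat16 leaf cast to float32."""
    def cast(leaf):
        # Only bfloat16 leaves are touched; anything else is passed through.
        if getattr(leaf, "dtype", None) == jnp.bfloat16:
            return leaf.astype(jnp.float32)
        return leaf

    return jax.tree_util.tree_map(cast, params)


# Hypothetical toy parameters; the key names mimic two of the entries listed
# in the warning, but the shapes are made up for this example.
toy_params = {
    "quant_conv": {
        "kernel": jnp.ones((1, 1, 8, 8), dtype=jnp.bfloat16),
        "bias": jnp.zeros((8,), dtype=jnp.bfloat16),
    }
}

fp32_params = upcast_to_fp32(toy_params)

# Show the dtype of each leaf after the cast.
print(jax.tree_util.tree_map(lambda x: x.dtype, fp32_params))
```

In this notebook the bfloat16 load is intentional, because that precision suits TPU inference, so the warning can be ignored here. An upcast like this is mainly of interest if you need float32 weights, at the cost of roughly twice the memory.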
Some of the weights of FlaxUNet2DConditionModel were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/unet:
[('conv_in', 'bias'), ('conv_in', 'kernel'), ('conv_norm_out', 'bias'), ('conv_norm_out', 'scale'), ('conv_out', 'bias'), ('conv_out', 'kernel'), ('down_blocks_0', 'attentions_0', 'norm', 'bias'), ('down_blocks_0', 'attentions_0', 'norm', 'scale'), ('down_blocks_0', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_0', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_0', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_0', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), 
('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_0', 'attentions_1', 'norm', 'bias'), ('down_blocks_0', 'attentions_1', 'norm', 'scale'), ('down_blocks_0', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_0', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_0', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_0', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm2', 
'scale'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_0', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_0', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_0', 'resnets_0', 'conv1', 'bias'), ('down_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_0', 'resnets_0', 'conv2', 'bias'), ('down_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_0', 'resnets_0', 'norm1', 'bias'), ('down_blocks_0', 'resnets_0', 'norm1', 'scale'), ('down_blocks_0', 'resnets_0', 'norm2', 'bias'), ('down_blocks_0', 'resnets_0', 'norm2', 'scale'), ('down_blocks_0', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_0', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_0', 'resnets_1', 'conv1', 'bias'), ('down_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_0', 'resnets_1', 'conv2', 'bias'), ('down_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_0', 'resnets_1', 'norm1', 'bias'), ('down_blocks_0', 'resnets_1', 'norm1', 'scale'), ('down_blocks_0', 'resnets_1', 'norm2', 'bias'), ('down_blocks_0', 'resnets_1', 'norm2', 'scale'), ('down_blocks_0', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_0', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_1', 'attentions_0', 'norm', 'bias'), ('down_blocks_1', 'attentions_0', 'norm', 'scale'), ('down_blocks_1', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_1', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_1', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_1', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_1', 
'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_1', 'attentions_1', 'norm', 'bias'), ('down_blocks_1', 'attentions_1', 'norm', 'scale'), ('down_blocks_1', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_1', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_1', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_1', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), 
('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_1', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_1', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv1', 'bias'), ('down_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv2', 'bias'), ('down_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv_shortcut', 'bias'), ('down_blocks_1', 'resnets_0', 'conv_shortcut', 'kernel'), ('down_blocks_1', 'resnets_0', 'norm1', 'bias'), ('down_blocks_1', 'resnets_0', 'norm1', 'scale'), ('down_blocks_1', 'resnets_0', 'norm2', 'bias'), ('down_blocks_1', 'resnets_0', 'norm2', 'scale'), ('down_blocks_1', 'resnets_0', 'time_emb_proj', 
'bias'), ('down_blocks_1', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_1', 'resnets_1', 'conv1', 'bias'), ('down_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_1', 'resnets_1', 'conv2', 'bias'), ('down_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_1', 'resnets_1', 'norm1', 'bias'), ('down_blocks_1', 'resnets_1', 'norm1', 'scale'), ('down_blocks_1', 'resnets_1', 'norm2', 'bias'), ('down_blocks_1', 'resnets_1', 'norm2', 'scale'), ('down_blocks_1', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_1', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_2', 'attentions_0', 'norm', 'bias'), ('down_blocks_2', 'attentions_0', 'norm', 'scale'), ('down_blocks_2', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_2', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_2', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_2', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_2', 'attentions_0', 
'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_2', 'attentions_1', 'norm', 'bias'), ('down_blocks_2', 'attentions_1', 'norm', 'scale'), ('down_blocks_2', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_2', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_2', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_2', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_2', 
'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_2', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_2', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv1', 'bias'), ('down_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv2', 'bias'), ('down_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('down_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('down_blocks_2', 'resnets_0', 'norm1', 'bias'), ('down_blocks_2', 'resnets_0', 'norm1', 'scale'), ('down_blocks_2', 'resnets_0', 'norm2', 'bias'), ('down_blocks_2', 'resnets_0', 'norm2', 'scale'), ('down_blocks_2', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_2', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_2', 'resnets_1', 'conv1', 'bias'), ('down_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_2', 'resnets_1', 'conv2', 'bias'), ('down_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_2', 'resnets_1', 'norm1', 'bias'), ('down_blocks_2', 'resnets_1', 'norm1', 'scale'), ('down_blocks_2', 'resnets_1', 'norm2', 'bias'), ('down_blocks_2', 'resnets_1', 'norm2', 'scale'), ('down_blocks_2', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_2', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_3', 'resnets_0', 'conv1', 'bias'), ('down_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_3', 'resnets_0', 'conv2', 
'bias'), ('down_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_3', 'resnets_0', 'norm1', 'bias'), ('down_blocks_3', 'resnets_0', 'norm1', 'scale'), ('down_blocks_3', 'resnets_0', 'norm2', 'bias'), ('down_blocks_3', 'resnets_0', 'norm2', 'scale'), ('down_blocks_3', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_3', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_3', 'resnets_1', 'conv1', 'bias'), ('down_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_3', 'resnets_1', 'conv2', 'bias'), ('down_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_3', 'resnets_1', 'norm1', 'bias'), ('down_blocks_3', 'resnets_1', 'norm1', 'scale'), ('down_blocks_3', 'resnets_1', 'norm2', 'bias'), ('down_blocks_3', 'resnets_1', 'norm2', 'scale'), ('down_blocks_3', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_3', 'resnets_1', 'time_emb_proj', 'kernel'), ('mid_block', 'attentions_0', 'norm', 'bias'), ('mid_block', 'attentions_0', 'norm', 'scale'), ('mid_block', 'attentions_0', 'proj_in', 'bias'), ('mid_block', 'attentions_0', 'proj_in', 'kernel'), ('mid_block', 'attentions_0', 'proj_out', 'bias'), ('mid_block', 'attentions_0', 'proj_out', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 
'attn2', 'to_v', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('mid_block', 'resnets_0', 'conv1', 'bias'), ('mid_block', 'resnets_0', 'conv1', 'kernel'), ('mid_block', 'resnets_0', 'conv2', 'bias'), ('mid_block', 'resnets_0', 'conv2', 'kernel'), ('mid_block', 'resnets_0', 'norm1', 'bias'), ('mid_block', 'resnets_0', 'norm1', 'scale'), ('mid_block', 'resnets_0', 'norm2', 'bias'), ('mid_block', 'resnets_0', 'norm2', 'scale'), ('mid_block', 'resnets_0', 'time_emb_proj', 'bias'), ('mid_block', 'resnets_0', 'time_emb_proj', 'kernel'), ('mid_block', 'resnets_1', 'conv1', 'bias'), ('mid_block', 'resnets_1', 'conv1', 'kernel'), ('mid_block', 'resnets_1', 'conv2', 'bias'), ('mid_block', 'resnets_1', 'conv2', 'kernel'), ('mid_block', 'resnets_1', 'norm1', 'bias'), ('mid_block', 'resnets_1', 'norm1', 'scale'), ('mid_block', 'resnets_1', 'norm2', 'bias'), ('mid_block', 'resnets_1', 'norm2', 'scale'), ('mid_block', 'resnets_1', 'time_emb_proj', 'bias'), ('mid_block', 'resnets_1', 'time_emb_proj', 'kernel'), ('time_embedding', 'linear_1', 'bias'), ('time_embedding', 'linear_1', 'kernel'), ('time_embedding', 'linear_2', 'bias'), ('time_embedding', 'linear_2', 'kernel'), ('up_blocks_0', 'resnets_0', 'conv1', 'bias'), ('up_blocks_0', 'resnets_0', 'conv1', 
'kernel'), ('up_blocks_0', 'resnets_0', 'conv2', 'bias'), ('up_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_0', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_0', 'norm1', 'bias'), ('up_blocks_0', 'resnets_0', 'norm1', 'scale'), ('up_blocks_0', 'resnets_0', 'norm2', 'bias'), ('up_blocks_0', 'resnets_0', 'norm2', 'scale'), ('up_blocks_0', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv1', 'bias'), ('up_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv2', 'bias'), ('up_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_1', 'norm1', 'bias'), ('up_blocks_0', 'resnets_1', 'norm1', 'scale'), ('up_blocks_0', 'resnets_1', 'norm2', 'bias'), ('up_blocks_0', 'resnets_1', 'norm2', 'scale'), ('up_blocks_0', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv1', 'bias'), ('up_blocks_0', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv2', 'bias'), ('up_blocks_0', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_2', 'norm1', 'bias'), ('up_blocks_0', 'resnets_2', 'norm1', 'scale'), ('up_blocks_0', 'resnets_2', 'norm2', 'bias'), ('up_blocks_0', 'resnets_2', 'norm2', 'scale'), ('up_blocks_0', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_2', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'upsamplers_0', 'conv', 'bias'), ('up_blocks_0', 'upsamplers_0', 'conv', 'kernel'), ('up_blocks_1', 'attentions_0', 'norm', 'bias'), ('up_blocks_1', 'attentions_0', 'norm', 'scale'), ('up_blocks_1', 'attentions_0', 'proj_in', 'bias'), ('up_blocks_1', 
'attentions_0', 'proj_in', 'kernel'), ('up_blocks_1', 'attentions_0', 'proj_out', 'bias'), ('up_blocks_1', 'attentions_0', 'proj_out', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_1', 'attentions_1', 'norm', 'bias'), ('up_blocks_1', 'attentions_1', 'norm', 'scale'), ('up_blocks_1', 'attentions_1', 'proj_in', 'bias'), ('up_blocks_1', 'attentions_1', 'proj_in', 'kernel'), 
('up_blocks_1', 'attentions_1', 'proj_out', 'bias'), ('up_blocks_1', 'attentions_1', 'proj_out', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_1', 'attentions_2', 'norm', 'bias'), ('up_blocks_1', 'attentions_2', 'norm', 'scale'), ('up_blocks_1', 'attentions_2', 'proj_in', 'bias'), ('up_blocks_1', 'attentions_2', 'proj_in', 'kernel'), ('up_blocks_1', 'attentions_2', 'proj_out', 
'bias'), ('up_blocks_1', 'attentions_2', 'proj_out', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_1', 'resnets_0', 'conv1', 'bias'), ('up_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('up_blocks_1', 'resnets_0', 'conv2', 'bias'), ('up_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_1', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_1', 'resnets_0', 'conv_shortcut', 
'kernel'), ('up_blocks_1', 'resnets_0', 'norm1', 'bias'), ('up_blocks_1', 'resnets_0', 'norm1', 'scale'), ('up_blocks_1', 'resnets_0', 'norm2', 'bias'), ('up_blocks_1', 'resnets_0', 'norm2', 'scale'), ('up_blocks_1', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_1', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_1', 'resnets_1', 'conv1', 'bias'), ('up_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('up_blocks_1', 'resnets_1', 'conv2', 'bias'), ('up_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_1', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_1', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_1', 'resnets_1', 'norm1', 'bias'), ('up_blocks_1', 'resnets_1', 'norm1', 'scale'), ('up_blocks_1', 'resnets_1', 'norm2', 'bias'), ('up_blocks_1', 'resnets_1', 'norm2', 'scale'), ('up_blocks_1', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_1', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_1', 'resnets_2', 'conv1', 'bias'), ('up_blocks_1', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_1', 'resnets_2', 'conv2', 'bias'), ('up_blocks_1', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_1', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_1', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_1', 'resnets_2', 'norm1', 'bias'), ('up_blocks_1', 'resnets_2', 'norm1', 'scale'), ('up_blocks_1', 'resnets_2', 'norm2', 'bias'), ('up_blocks_1', 'resnets_2', 'norm2', 'scale'), ('up_blocks_1', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_1', 'resnets_2', 'time_emb_proj', 'kernel'), ('up_blocks_1', 'upsamplers_0', 'conv', 'bias'), ('up_blocks_1', 'upsamplers_0', 'conv', 'kernel'), ('up_blocks_2', 'attentions_0', 'norm', 'bias'), ('up_blocks_2', 'attentions_0', 'norm', 'scale'), ('up_blocks_2', 'attentions_0', 'proj_in', 'bias'), ('up_blocks_2', 'attentions_0', 'proj_in', 'kernel'), ('up_blocks_2', 'attentions_0', 'proj_out', 'bias'), ('up_blocks_2', 'attentions_0', 'proj_out', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 
'attn1', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_2', 'attentions_1', 'norm', 'bias'), ('up_blocks_2', 'attentions_1', 'norm', 'scale'), ('up_blocks_2', 'attentions_1', 'proj_in', 'bias'), ('up_blocks_2', 'attentions_1', 'proj_in', 'kernel'), ('up_blocks_2', 'attentions_1', 'proj_out', 'bias'), ('up_blocks_2', 'attentions_1', 'proj_out', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_2', 
'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_2', 'attentions_2', 'norm', 'bias'), ('up_blocks_2', 'attentions_2', 'norm', 'scale'), ('up_blocks_2', 'attentions_2', 'proj_in', 'bias'), ('up_blocks_2', 'attentions_2', 'proj_in', 'kernel'), ('up_blocks_2', 'attentions_2', 'proj_out', 'bias'), ('up_blocks_2', 'attentions_2', 'proj_out', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 
'attn1', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_2', 'resnets_0', 'conv1', 'bias'), ('up_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('up_blocks_2', 'resnets_0', 'conv2', 'bias'), ('up_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('up_blocks_2', 'resnets_0', 'norm1', 'bias'), ('up_blocks_2', 'resnets_0', 'norm1', 'scale'), ('up_blocks_2', 'resnets_0', 'norm2', 'bias'), ('up_blocks_2', 'resnets_0', 'norm2', 'scale'), 
('up_blocks_2', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_2', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_2', 'resnets_1', 'conv1', 'bias'), ('up_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('up_blocks_2', 'resnets_1', 'conv2', 'bias'), ('up_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_2', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_2', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_2', 'resnets_1', 'norm1', 'bias'), ('up_blocks_2', 'resnets_1', 'norm1', 'scale'), ('up_blocks_2', 'resnets_1', 'norm2', 'bias'), ('up_blocks_2', 'resnets_1', 'norm2', 'scale'), ('up_blocks_2', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_2', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_2', 'resnets_2', 'conv1', 'bias'), ('up_blocks_2', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_2', 'resnets_2', 'conv2', 'bias'), ('up_blocks_2', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_2', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_2', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_2', 'resnets_2', 'norm1', 'bias'), ('up_blocks_2', 'resnets_2', 'norm1', 'scale'), ('up_blocks_2', 'resnets_2', 'norm2', 'bias'), ('up_blocks_2', 'resnets_2', 'norm2', 'scale'), ('up_blocks_2', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_2', 'resnets_2', 'time_emb_proj', 'kernel'), ('up_blocks_2', 'upsamplers_0', 'conv', 'bias'), ('up_blocks_2', 'upsamplers_0', 'conv', 'kernel'), ('up_blocks_3', 'attentions_0', 'norm', 'bias'), ('up_blocks_3', 'attentions_0', 'norm', 'scale'), ('up_blocks_3', 'attentions_0', 'proj_in', 'bias'), ('up_blocks_3', 'attentions_0', 'proj_in', 'kernel'), ('up_blocks_3', 'attentions_0', 'proj_out', 'bias'), ('up_blocks_3', 'attentions_0', 'proj_out', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), 
('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_3', 'attentions_1', 'norm', 'bias'), ('up_blocks_3', 'attentions_1', 'norm', 'scale'), ('up_blocks_3', 'attentions_1', 'proj_in', 'bias'), ('up_blocks_3', 'attentions_1', 'proj_in', 'kernel'), ('up_blocks_3', 'attentions_1', 'proj_out', 'bias'), ('up_blocks_3', 'attentions_1', 'proj_out', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_1', 
'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_3', 'attentions_2', 'norm', 'bias'), ('up_blocks_3', 'attentions_2', 'norm', 'scale'), ('up_blocks_3', 'attentions_2', 'proj_in', 'bias'), ('up_blocks_3', 'attentions_2', 'proj_in', 'kernel'), ('up_blocks_3', 'attentions_2', 'proj_out', 'bias'), ('up_blocks_3', 'attentions_2', 'proj_out', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 
'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_3', 'resnets_0', 'conv1', 'bias'), ('up_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('up_blocks_3', 'resnets_0', 'conv2', 'bias'), ('up_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_3', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_3', 'resnets_0', 'conv_shortcut', 'kernel'), ('up_blocks_3', 'resnets_0', 'norm1', 'bias'), ('up_blocks_3', 'resnets_0', 'norm1', 'scale'), ('up_blocks_3', 'resnets_0', 'norm2', 'bias'), ('up_blocks_3', 'resnets_0', 'norm2', 'scale'), ('up_blocks_3', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_3', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_3', 'resnets_1', 'conv1', 'bias'), ('up_blocks_3', 'resnets_1', 'conv1', 
'kernel'), ('up_blocks_3', 'resnets_1', 'conv2', 'bias'), ('up_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_3', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_3', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_3', 'resnets_1', 'norm1', 'bias'), ('up_blocks_3', 'resnets_1', 'norm1', 'scale'), ('up_blocks_3', 'resnets_1', 'norm2', 'bias'), ('up_blocks_3', 'resnets_1', 'norm2', 'scale'), ('up_blocks_3', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_3', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_3', 'resnets_2', 'conv1', 'bias'), ('up_blocks_3', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_3', 'resnets_2', 'conv2', 'bias'), ('up_blocks_3', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_3', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_3', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_3', 'resnets_2', 'norm1', 'bias'), ('up_blocks_3', 'resnets_2', 'norm1', 'scale'), ('up_blocks_3', 'resnets_2', 'norm2', 'bias'), ('up_blocks_3', 'resnets_2', 'norm2', 'scale'), ('up_blocks_3', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_3', 'resnets_2', 'time_emb_proj', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~ModelMixin.to_fp32`] for further information on how to do this.
9.3 Inference
Since TPUs usually have 8 devices working in parallel, we’ll replicate our prompt as many times as devices we have. Then we’ll perform inference on the 8 devices at once, each responsible for generating one image. Thus, we’ll get 8 images in the same amount of time it takes for one chip to generate a single one.
After replicating the prompt, we obtain the tokenized text ids by invoking the prepare_inputs function of the pipeline. The length of the tokenized text is set to 77 tokens, as required by the configuration of the underlying CLIP Text model.
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
(8, 77)
9.3.1 Replication and parallelization
Model parameters and inputs have to be distributed across the 8 parallel devices we have. The parameters dictionary is replicated using flax.jax_utils.replicate, which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Input arrays are split across the devices using shard, which reshapes them so that the leading axis matches the number of devices.
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
(8, 1, 77)
That shape means that each one of the 8 devices will receive as an input a jnp array with shape (1, 77). 1 is therefore the batch size per device. In TPUs with sufficient memory, it could be larger than 1 if we wanted to generate multiple images (per chip) at once.
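As a rough illustration of the resulting shapes, the sketch below mimics both helpers with plain NumPy instead of calling Flax. It assumes 8 devices, and the tiny params dictionary is a hypothetical stand-in for the real model weights. The actual helpers also place the data on the devices, which this sketch does not attempt.

```python
import numpy as np

num_devices = 8  # assumption: one TPU host with 8 devices

# Hypothetical stand-in for the (much larger, nested) parameter dictionary.
params = {"kernel": np.zeros((3, 3)), "bias": np.zeros((3,))}

# replicate-like behaviour: every array gains a leading device axis.
p_params = {name: np.stack([w] * num_devices) for name, w in params.items()}

# shard-like behaviour: split the batch axis into (devices, per-device batch).
prompt_ids = np.zeros((8, 77), dtype=np.int32)
sharded_ids = prompt_ids.reshape((num_devices, -1) + prompt_ids.shape[1:])

print(p_params["kernel"].shape)  # (8, 3, 3)
print(sharded_ids.shape)         # (8, 1, 77)
```

With 16 prompts instead of 8, the same reshape would give (8, 2, 77), that is, a per-device batch size of 2.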
We are almost ready to generate images! We just need to create a random number generator to pass to the generation function. This is the standard procedure in Flax, which is very serious and opinionated about random numbers – all functions that deal with random numbers are expected to receive a generator. This ensures reproducibility, even when we are training across multiple distributed devices.
The helper function below uses a seed to initialize a random number generator. As long as we use the same seed, we’ll get the exact same results. Feel free to use different seeds when exploring results later in the notebook.
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
We obtain a rng and then “split” it 8 times so each device receives a different generator. Therefore, each device will create a different image, and the full process is reproducible.
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
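The two properties mentioned above, reproducibility and a different generator per device, can be checked with a small sketch. It runs on any JAX backend; the value 8 for num_keys is an assumption matching the 8 devices used in this notebook.

```python
import jax
import numpy as np

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

num_keys = 8  # assumption: one key per device on an 8-device host

keys_a = np.asarray(jax.random.split(create_key(0), num_keys))
keys_b = np.asarray(jax.random.split(create_key(0), num_keys))

# Same seed, same split: the per-device keys are identical across runs.
reproducible = bool(np.array_equal(keys_a, keys_b))

# Within one split, every device gets its own key.
all_distinct = len({tuple(k.tolist()) for k in keys_a}) == num_keys

print(reproducible, all_distinct)
```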
JAX code can be compiled to an efficient representation that runs very fast. However, we need to ensure that all inputs have the same shape in subsequent calls; otherwise, JAX will have to recompile the code, and we wouldn’t be able to take advantage of the optimized speed.
The Flax pipeline can compile the code for us if we pass jit=True as an argument. It will also ensure that the model runs in parallel on the 8 available devices.
The first time we run the following cell it will take a long time to compile, but subsequent calls (even with different inputs) will be much faster. For example, it took more than a minute to compile on a TPU v2-8 when I tested, but then it takes about 7s for future inference runs.
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s
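The compile-once behaviour, and the recompilation triggered by a change of input shape, can be observed on a toy function. This is a minimal sketch that is independent of the pipeline: double and traced_shapes are hypothetical names, and the Python side effect inside the function only runs while JAX traces it.

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # filled only while JAX traces the function

@jax.jit
def double(x):
    traced_shapes.append(x.shape)  # Python side effect: runs at trace time only
    return x * 2

double(jnp.ones((8, 77)))
double(jnp.zeros((8, 77)))  # same shape and dtype: cached version is reused
traces_after_same_shape = len(traced_shapes)

double(jnp.ones((4, 77)))   # new shape: JAX traces and compiles again
traces_after_new_shape = len(traced_shapes)

print(traces_after_same_shape, traces_after_new_shape)
```

Different values with the same shape reuse the compiled code, which is why new prompts of the same tokenized length do not pay the compilation cost again.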
The returned array has shape (8, 1, 512, 512, 3). We reshape it to get rid of the second dimension and obtain 8 images of 512 × 512 × 3 and then convert them to PIL.
images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
9.3.2 Visualization
Let’s create a helper function to display images in a grid.
def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
image_grid(images, 2, 4)
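To see where each image lands in the grid, the sketch below reproduces only the box arithmetic used by image_grid, without creating any images. The helper paste_boxes is a hypothetical name introduced for illustration.

```python
def paste_boxes(n_images, cols, w, h):
    # Same arithmetic as image_grid: column from i % cols, row from i // cols.
    return [(i % cols * w, i // cols * h) for i in range(n_images)]

boxes = paste_boxes(n_images=8, cols=4, w=512, h=512)
print(boxes[0], boxes[3], boxes[4], boxes[7])
```

Images 0 to 3 fill the first row from left to right, and image 4 starts the second row at x = 0.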
9.4 Using different prompts
We don’t have to replicate the same prompt in all the devices. We can do whatever we want: generate 2 prompts 4 times each, or even generate 8 different prompts at once. Let’s do that!
First, we’ll prepare and shard the inputs for the new prompts, following the same steps as before:
prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background"
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)
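The other option mentioned above, 2 prompts generated 4 times each, only changes how the list of prompts is built. A minimal sketch, assuming 8 devices; base_prompts and repeats are hypothetical names:

```python
num_devices = 8  # assumption: matches the 8 TPU devices used in this notebook

base_prompts = [
    "Labrador in the style of Hokusai",
    "Armchair in the shape of an avocado",
]

# Repeat each prompt so the list has exactly one entry per device.
repeats = num_devices // len(base_prompts)
prompts = [p for p in base_prompts for _ in range(repeats)]

print(len(prompts))  # 8
```

The resulting list can then go through prepare_inputs and shard exactly like the 8 distinct prompts. Because each device still receives its own random key, the 4 copies of a prompt produce 4 different images.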
9.5 How does parallelization work?
We said before that the diffusers Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We’ll now briefly look inside that process to show how it works.
JAX parallelization can be done in multiple ways. The easiest one revolves around using the jax.pmap function to achieve single-program, multiple-data (SPMD) parallelization. It means we’ll run several copies of the same code, each on different data inputs. More sophisticated approaches are possible. We invite you to go over the JAX documentation and the pjit pages to explore this topic if you are interested!
jax.pmap does two things for us:
- Compiles (or jits) the code, as if we had invoked jax.jit(). This does not happen when we call pmap, but the first time the pmapped function is invoked.
- Ensures the compiled code runs in parallel in all the available devices.
To show how it works we pmap the _generate method of the pipeline, which is the private method that generates images. Please note that this method may be renamed or removed in future releases of diffusers.
p_generate = pmap(pipeline._generate)
After we use pmap, the prepared function p_generate will conceptually do the following:
* Invoke a copy of the underlying function pipeline._generate in each device.
* Send each device a different portion of the input arguments. That’s what sharding is used for. In our case, prompt_ids has shape (8, 1, 77). This array will be split in 8 and each copy of _generate will receive an input with shape (1, 77).
We can code _generate completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (1 in this example) and the dimensions that make sense for our code, and don’t have to change anything to make it work in parallel.
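The same idea can be tried on a toy function that knows nothing about devices. This is a sketch rather than the pipeline's code: count_tokens and seen_shapes are hypothetical names, and the number of devices is read from the backend so the example also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()
seen_shapes = []  # per-device input shapes, recorded while JAX traces

def count_tokens(ids):
    # Written for a single device: ids arrives without the device axis.
    seen_shapes.append(ids.shape)
    return ids.sum(axis=-1)

p_count_tokens = jax.pmap(count_tokens)

# Leading axis = devices, then per-device batch size 1, then 77 token ids.
ids = jnp.ones((num_devices, 1, 77), dtype=jnp.int32)
out = p_count_tokens(ids)

print(seen_shapes[0])  # (1, 77)
print(out.shape)       # (num_devices, 1)
```

The function body only ever sees the (1, 77) slice, while the output regains the leading device axis, in the same way the images array returned by the pipeline does.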
The same way as when we used the pipeline call, the first time we run the following cell it will take a while, but then it will be much faster.
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
(8, 1, 512, 512, 3)
images.shape
(8, 1, 512, 512, 3)
We use block_until_ready() to correctly measure inference time, because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don’t need to use that in your code; blocking will occur automatically when you want to use the result of a computation that has not yet been materialized.
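A minimal timing sketch of that pattern, using a small matrix product as a stand-in for the pipeline (the function square and the array size are arbitrary choices for illustration): warm up once so compilation is excluded, then block before reading the clock.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def square(a):
    return a @ a

x = jnp.ones((256, 256))

# Warm-up call: compiles the function so the timing below excludes compilation.
square(x).block_until_ready()

start = time.perf_counter()
y = square(x).block_until_ready()  # wait until the result is actually computed
elapsed = time.perf_counter() - start

print(f"{elapsed * 1000:.3f} ms")
```

Without the call to block_until_ready(), the timer could stop as soon as the work is dispatched, before the device has finished computing.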