9 Stable Diffusion in JAX / Flax

We are grateful to be able to share this notebook from Hugging Face. Related blog post here.


🤗 Hugging Face Diffusers has supported Flax since version 0.5.1! This enables very fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform.

This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to this Colab notebook.

First, make sure you are using a TPU backend. If you are running this notebook in Colab, open the Runtime menu, choose “Change runtime type”, and select TPU under the Hardware accelerator setting.

Note that JAX is not exclusive to TPUs, but it shines on that hardware because each TPU server exposes 8 TPU devices (cores) working in parallel.
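To see what "working in parallel" means in JAX terms, here is a minimal sketch using `jax.pmap`, which compiles a function once and runs one copy per device, splitting the input along its leading axis. The array and the sum-of-squares function are made up for illustration. On a Colab TPU runtime `jax.local_device_count()` is typically 8, and on a CPU-only runtime it is usually 1, so the same code runs on either.

```python
import jax
import jax.numpy as jnp

# Number of devices attached to this host (typically 8 on a Colab TPU runtime).
n_dev = jax.local_device_count()

# pmap maps over the leading axis, so it must have exactly one entry per device.
x = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)

# The lambda is compiled once; device i only ever sees row i of `x`.
sum_of_squares = jax.pmap(lambda row: jnp.sum(row ** 2))(x)

print(n_dev, sum_of_squares)  # one result per device
```

The only contract is that the leading axis equals the device count; the `replicate` and `shard` utilities imported below exist to produce inputs of exactly that shape.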

9.1 Setup

!pip install flax transformers ftfy
!pip install diffusers==0.9.0
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
Found 8 JAX devices of type Cloud TPU.
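The assert above stops the notebook on anything other than a TPU. If you only want to debug the code on CPU or GPU (much more slowly), a softer check could look like the sketch below. `summarize_devices` is a hypothetical helper, not part of JAX; it relies only on `jax.devices()`, each device's `platform` attribute, and `jax.default_backend()`.

```python
from collections import Counter

import jax


def summarize_devices():
    """Hypothetical helper: return a {platform: count} mapping of visible devices."""
    return dict(Counter(d.platform for d in jax.devices()))


counts = summarize_devices()
print(counts)  # typically {'tpu': 8} on a Colab TPU runtime

if "tpu" not in counts:
    # Warn instead of asserting: the code still runs on CPU/GPU, just far slower.
    print("Warning: no TPU found; falling back to", jax.default_backend())
```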

Then we import all the dependencies.

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
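Of these imports, `pmap`, `replicate` and `shard` carry the parallelism: Flax's `replicate` places one copy of a pytree (such as the model parameters) on every local device, and `shard` reshapes a batch so that its leading axis is split across devices. The sketch below mimics those shapes on the host with two hypothetical helpers, `shard_batch` and `replicate_tree`, using toy token IDs and a toy one-leaf parameter tree. It illustrates the data layout only; it is not the Flax implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_dev = jax.local_device_count()


def shard_batch(batch):
    """Hypothetical helper: reshape (batch, ...) into (n_dev, batch // n_dev, ...)."""
    return jax.tree_util.tree_map(lambda x: x.reshape((n_dev, -1) + x.shape[1:]), batch)


def replicate_tree(tree):
    """Hypothetical helper: stack one host-side copy of every leaf per device."""
    return jax.tree_util.tree_map(lambda x: np.stack([np.asarray(x)] * n_dev), tree)


# Toy stand-ins: 2 "prompts" of 5 token IDs per device, and a tiny parameter tree.
token_ids = np.arange(n_dev * 2 * 5, dtype=np.int32).reshape(n_dev * 2, 5)
toy_params = {"w": np.full((5,), 0.5, dtype=np.float32)}

sharded_ids = shard_batch(token_ids)      # shape (n_dev, 2, 5)
replicated = replicate_tree(toy_params)   # replicated["w"] has shape (n_dev, 5)


@jax.pmap
def score(p, ids):
    # Inside pmap each device sees its own (2, 5) slice and its own (5,) weights.
    return ids.astype(jnp.float32) @ p["w"]


scores = score(replicated, sharded_ids)  # shape (n_dev, 2)
```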

9.2 Model Loading

Before using the model, you need to accept the model license in order to download and use the weights.

The license is designed to mitigate the potential harmful effects of such a powerful machine learning system. We request users to read the license entirely and carefully. Here we offer a summary:

  1. You can’t use the model to deliberately produce or share illegal or harmful outputs or content,

  2. We claim no rights on the outputs you generate; you are free to use them and are accountable for their use, which should not go against the provisions set in the license, and

  3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M with all your users.

Flax weights are available on the Hugging Face Hub as part of the Stable Diffusion repo. To use them, you need to be a registered user on the Hugging Face Hub and use an access token for the code to work. You have two options to provide your access token:

  • Use the huggingface-cli login command-line tool in your terminal and paste your token when prompted. It will be saved in a file on your computer.
  • Or use notebook_login() in a notebook, which does the same thing.

The following cell will present a login interface unless you’ve already authenticated on this computer. You’ll need to paste your access token.

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()
Login successful
Your token has been saved to /root/.huggingface/token
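The cell above only looks for `~/.huggingface/token`, the location used by the `huggingface_hub` releases of the time. As far as we know, newer releases store the token under `~/.cache/huggingface/token` (or `$HF_HOME/token`) and also read an `HF_TOKEN` environment variable, so on a newer setup the cell may prompt even though you are logged in. The sketch below is a hypothetical helper, `hf_token_available`, that checks all of these places; the newer paths are assumptions, not something this notebook relies on.

```python
import os
from pathlib import Path


def hf_token_available() -> bool:
    """Hypothetical helper: True if a Hugging Face token can probably be found."""
    # Newer huggingface_hub releases also honor a token passed via this variable.
    if os.environ.get("HF_TOKEN"):
        return True
    hf_home = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
    candidates = [
        Path.home() / ".huggingface" / "token",  # legacy location checked in the cell above
        hf_home / "token",                       # location used by newer releases (assumption)
    ]
    return any(p.exists() for p in candidates)


# Usage in the notebook: only prompt when no token was found.
# if not hf_token_available(): notebook_login()
```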

TPU devices natively support bfloat16, a 16-bit floating-point format that keeps the exponent range of float32 but has fewer mantissa bits. We’ll use it for our tests, but you can also use float32 for full precision instead.

dtype = jnp.bfloat16
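To see the trade-off bfloat16 makes, you can inspect its numeric limits with `jnp.finfo`. The sketch compares it with float16 and float32: bfloat16 has roughly float32's range, so large values rarely overflow, but a much larger machine epsilon (7 explicit mantissa bits versus float32's 23), and each value takes half the memory of float32.

```python
import jax.numpy as jnp

bf16 = jnp.finfo(jnp.bfloat16)
f16 = jnp.finfo(jnp.float16)
f32 = jnp.finfo(jnp.float32)

# Range: bfloat16 is close to float32, float16 tops out at 65504.
print("max   bf16:", bf16.max, " f16:", f16.max, " f32:", f32.max)
# Precision: the spacing of numbers near 1.0 is much coarser in bfloat16.
print("eps   bf16:", bf16.eps, " f16:", f16.eps, " f32:", f32.eps)

# Memory: 2 bytes per value instead of 4.
x = jnp.ones((1024,), dtype=jnp.bfloat16)
print(x.nbytes)  # 2048
```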

Flax is a functional framework, so models are stateless and parameters are stored outside them. Loading the pre-trained Flax pipeline will return both the pipeline itself and the model weights (or parameters). We are using a bf16 version of the weights, which leads to type warnings that you can safely ignore.
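The "stateless model, external parameters" idea is easy to show without Flax at all. In the toy sketch below, the model is a pure function of `(params, inputs)` and the parameters are a plain dict (a pytree), which is also why `flax.linen`'s `Module.apply` takes the variables as an explicit argument. `init_params` and `linear_apply` are hypothetical names for illustration; because the parameters live outside the model, they can be counted, cast to bfloat16, or replicated without touching the model code.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in, n_out):
    """Hypothetical: parameters of a toy linear layer, stored as a plain dict."""
    return {"kernel": 0.02 * jax.random.normal(key, (n_in, n_out)),
            "bias": jnp.zeros((n_out,))}


def linear_apply(params, x):
    """The 'model' holds no state: output depends only on params and inputs."""
    return x @ params["kernel"] + params["bias"]


toy_params = init_params(jax.random.PRNGKey(0), 4, 3)
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(toy_params))  # 4*3 + 3 = 15

# Cast every leaf to bfloat16 in one pass over the pytree.
toy_params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), toy_params)

x = jnp.ones((2, 4))
y32 = linear_apply(toy_params, x)
y16 = linear_apply(toy_params_bf16, x.astype(jnp.bfloat16))
```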

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)
Some of the weights of FlaxStableDiffusionSafetyChecker were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/safety_checker:
[('concept_embeds',), ('concept_embeds_weights',), ('special_care_embeds',), ('special_care_embeds_weights',), ('vision_model', 'vision_model', 'embeddings', 'class_embedding'), ...
'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', 
'9', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'post_layernorm', 'bias'), ('vision_model', 'vision_model', 'post_layernorm', 'scale'), ('vision_model', 'vision_model', 'pre_layrnorm', 'bias'), ('vision_model', 'vision_model', 'pre_layrnorm', 'scale'), ('visual_projection', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
Some of the weights of FlaxCLIPTextModel were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/text_encoder:
[('text_model', 'embeddings', 'position_embedding', 'embedding'), ('text_model', 'embeddings', 'token_embedding', 'embedding'), ('text_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), 
('text_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 
'layers', '11', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '3', 
'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), 
('text_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 
'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 
'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'final_layer_norm', 'bias'), ('text_model', 'final_layer_norm', 'scale')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
Some of the weights of FlaxAutoencoderKL were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/vae:
[('decoder', 'conv_in', 'bias'), ('decoder', 'conv_in', 'kernel'), ('decoder', 'conv_norm_out', 'bias'), ('decoder', 'conv_norm_out', 'scale'), ('decoder', 'conv_out', 'bias'), ('decoder', 'conv_out', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'group_norm', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'group_norm', 'scale'), ('decoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'key', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'conv1', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'conv2', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'norm1', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'norm1', 'scale'), ('decoder', 'mid_block', 'resnets_0', 'norm2', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'norm2', 'scale'), ('decoder', 'mid_block', 'resnets_1', 'conv1', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'mid_block', 'resnets_1', 'conv2', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'mid_block', 'resnets_1', 'norm1', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'norm1', 'scale'), ('decoder', 'mid_block', 'resnets_1', 'norm2', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 
'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_0', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('decoder', 
'up_blocks_1', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_1', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv1', 
'kernel'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_2', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv_shortcut', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv_shortcut', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 
'resnets_2', 'norm2', 'scale'), ('encoder', 'conv_in', 'bias'), ('encoder', 'conv_in', 'kernel'), ('encoder', 'conv_norm_out', 'bias'), ('encoder', 'conv_norm_out', 'scale'), ('encoder', 'conv_out', 'bias'), ('encoder', 'conv_out', 'kernel'), ('encoder', 'down_blocks_0', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_0', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_1', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_1', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv_shortcut', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv_shortcut', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm1', 'scale'), ('encoder', 
'down_blocks_1', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_2', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_2', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_0', 
'conv2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm2', 'scale'), ('encoder', 'mid_block', 'attentions_0', 'group_norm', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'group_norm', 'scale'), ('encoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'key', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'conv1', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'conv2', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'norm1', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'norm1', 'scale'), ('encoder', 'mid_block', 'resnets_0', 'norm2', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'norm2', 'scale'), ('encoder', 'mid_block', 'resnets_1', 'conv1', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'conv1', 'kernel'), 
('encoder', 'mid_block', 'resnets_1', 'conv2', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'mid_block', 'resnets_1', 'norm1', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'norm1', 'scale'), ('encoder', 'mid_block', 'resnets_1', 'norm2', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'norm2', 'scale'), ('post_quant_conv', 'bias'), ('post_quant_conv', 'kernel'), ('quant_conv', 'bias'), ('quant_conv', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~ModelMixin.to_fp32`] for further information on how to do this.
Some of the weights of FlaxUNet2DConditionModel were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/unet:
[('conv_in', 'bias'), ('conv_in', 'kernel'), ('conv_norm_out', 'bias'), ('conv_norm_out', 'scale'), ('conv_out', 'bias'), ('conv_out', 'kernel'), ('down_blocks_0', 'attentions_0', 'norm', 'bias'), ('down_blocks_0', 'attentions_0', 'norm', 'scale'), ('down_blocks_0', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_0', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_0', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_0', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), 
('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_0', 'attentions_1', 'norm', 'bias'), ('down_blocks_0', 'attentions_1', 'norm', 'scale'), ('down_blocks_0', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_0', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_0', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_0', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm2', 
'scale'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_0', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_0', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_0', 'resnets_0', 'conv1', 'bias'), ('down_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_0', 'resnets_0', 'conv2', 'bias'), ('down_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_0', 'resnets_0', 'norm1', 'bias'), ('down_blocks_0', 'resnets_0', 'norm1', 'scale'), ('down_blocks_0', 'resnets_0', 'norm2', 'bias'), ('down_blocks_0', 'resnets_0', 'norm2', 'scale'), ('down_blocks_0', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_0', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_0', 'resnets_1', 'conv1', 'bias'), ('down_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_0', 'resnets_1', 'conv2', 'bias'), ('down_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_0', 'resnets_1', 'norm1', 'bias'), ('down_blocks_0', 'resnets_1', 'norm1', 'scale'), ('down_blocks_0', 'resnets_1', 'norm2', 'bias'), ('down_blocks_0', 'resnets_1', 'norm2', 'scale'), ('down_blocks_0', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_0', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_1', 'attentions_0', 'norm', 'bias'), ('down_blocks_1', 'attentions_0', 'norm', 'scale'), ('down_blocks_1', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_1', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_1', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_1', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_1', 
'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_1', 'attentions_1', 'norm', 'bias'), ('down_blocks_1', 'attentions_1', 'norm', 'scale'), ('down_blocks_1', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_1', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_1', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_1', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), 
('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_1', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_1', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv1', 'bias'), ('down_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv2', 'bias'), ('down_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv_shortcut', 'bias'), ('down_blocks_1', 'resnets_0', 'conv_shortcut', 'kernel'), ('down_blocks_1', 'resnets_0', 'norm1', 'bias'), ('down_blocks_1', 'resnets_0', 'norm1', 'scale'), ('down_blocks_1', 'resnets_0', 'norm2', 'bias'), ('down_blocks_1', 'resnets_0', 'norm2', 'scale'), ('down_blocks_1', 'resnets_0', 'time_emb_proj', 
'bias'), ('down_blocks_1', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_1', 'resnets_1', 'conv1', 'bias'), ('down_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_1', 'resnets_1', 'conv2', 'bias'), ('down_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_1', 'resnets_1', 'norm1', 'bias'), ('down_blocks_1', 'resnets_1', 'norm1', 'scale'), ('down_blocks_1', 'resnets_1', 'norm2', 'bias'), ('down_blocks_1', 'resnets_1', 'norm2', 'scale'), ('down_blocks_1', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_1', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_2', 'attentions_0', 'norm', 'bias'), ('down_blocks_2', 'attentions_0', 'norm', 'scale'), ('down_blocks_2', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_2', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_2', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_2', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_2', 'attentions_0', 
'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_2', 'attentions_1', 'norm', 'bias'), ('down_blocks_2', 'attentions_1', 'norm', 'scale'), ('down_blocks_2', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_2', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_2', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_2', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_2', 
'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_2', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_2', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv1', 'bias'), ('down_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv2', 'bias'), ('down_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('down_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('down_blocks_2', 'resnets_0', 'norm1', 'bias'), ('down_blocks_2', 'resnets_0', 'norm1', 'scale'), ('down_blocks_2', 'resnets_0', 'norm2', 'bias'), ('down_blocks_2', 'resnets_0', 'norm2', 'scale'), ('down_blocks_2', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_2', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_2', 'resnets_1', 'conv1', 'bias'), ('down_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_2', 'resnets_1', 'conv2', 'bias'), ('down_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_2', 'resnets_1', 'norm1', 'bias'), ('down_blocks_2', 'resnets_1', 'norm1', 'scale'), ('down_blocks_2', 'resnets_1', 'norm2', 'bias'), ('down_blocks_2', 'resnets_1', 'norm2', 'scale'), ('down_blocks_2', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_2', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_3', 'resnets_0', 'conv1', 'bias'), ('down_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_3', 'resnets_0', 'conv2', 
'bias'), ('down_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_3', 'resnets_0', 'norm1', 'bias'), ('down_blocks_3', 'resnets_0', 'norm1', 'scale'), ('down_blocks_3', 'resnets_0', 'norm2', 'bias'), ('down_blocks_3', 'resnets_0', 'norm2', 'scale'), ('down_blocks_3', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_3', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_3', 'resnets_1', 'conv1', 'bias'), ('down_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_3', 'resnets_1', 'conv2', 'bias'), ('down_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_3', 'resnets_1', 'norm1', 'bias'), ('down_blocks_3', 'resnets_1', 'norm1', 'scale'), ('down_blocks_3', 'resnets_1', 'norm2', 'bias'), ('down_blocks_3', 'resnets_1', 'norm2', 'scale'), ('down_blocks_3', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_3', 'resnets_1', 'time_emb_proj', 'kernel'), ('mid_block', 'attentions_0', 'norm', 'bias'), ('mid_block', 'attentions_0', 'norm', 'scale'), ('mid_block', 'attentions_0', 'proj_in', 'bias'), ('mid_block', 'attentions_0', 'proj_in', 'kernel'), ('mid_block', 'attentions_0', 'proj_out', 'bias'), ('mid_block', 'attentions_0', 'proj_out', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 
'attn2', 'to_v', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('mid_block', 'resnets_0', 'conv1', 'bias'), ('mid_block', 'resnets_0', 'conv1', 'kernel'), ('mid_block', 'resnets_0', 'conv2', 'bias'), ('mid_block', 'resnets_0', 'conv2', 'kernel'), ('mid_block', 'resnets_0', 'norm1', 'bias'), ('mid_block', 'resnets_0', 'norm1', 'scale'), ('mid_block', 'resnets_0', 'norm2', 'bias'), ('mid_block', 'resnets_0', 'norm2', 'scale'), ('mid_block', 'resnets_0', 'time_emb_proj', 'bias'), ('mid_block', 'resnets_0', 'time_emb_proj', 'kernel'), ('mid_block', 'resnets_1', 'conv1', 'bias'), ('mid_block', 'resnets_1', 'conv1', 'kernel'), ('mid_block', 'resnets_1', 'conv2', 'bias'), ('mid_block', 'resnets_1', 'conv2', 'kernel'), ('mid_block', 'resnets_1', 'norm1', 'bias'), ('mid_block', 'resnets_1', 'norm1', 'scale'), ('mid_block', 'resnets_1', 'norm2', 'bias'), ('mid_block', 'resnets_1', 'norm2', 'scale'), ('mid_block', 'resnets_1', 'time_emb_proj', 'bias'), ('mid_block', 'resnets_1', 'time_emb_proj', 'kernel'), ('time_embedding', 'linear_1', 'bias'), ('time_embedding', 'linear_1', 'kernel'), ('time_embedding', 'linear_2', 'bias'), ('time_embedding', 'linear_2', 'kernel'), ('up_blocks_0', 'resnets_0', 'conv1', 'bias'), ('up_blocks_0', 'resnets_0', 'conv1', 
'kernel'), ('up_blocks_0', 'resnets_0', 'conv2', 'bias'), ('up_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_0', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_0', 'norm1', 'bias'), ('up_blocks_0', 'resnets_0', 'norm1', 'scale'), ('up_blocks_0', 'resnets_0', 'norm2', 'bias'), ('up_blocks_0', 'resnets_0', 'norm2', 'scale'), ('up_blocks_0', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv1', 'bias'), ('up_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv2', 'bias'), ('up_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_1', 'norm1', 'bias'), ('up_blocks_0', 'resnets_1', 'norm1', 'scale'), ('up_blocks_0', 'resnets_1', 'norm2', 'bias'), ('up_blocks_0', 'resnets_1', 'norm2', 'scale'), ('up_blocks_0', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv1', 'bias'), ('up_blocks_0', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv2', 'bias'), ('up_blocks_0', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_2', 'norm1', 'bias'), ('up_blocks_0', 'resnets_2', 'norm1', 'scale'), ('up_blocks_0', 'resnets_2', 'norm2', 'bias'), ('up_blocks_0', 'resnets_2', 'norm2', 'scale'), ('up_blocks_0', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_2', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'upsamplers_0', 'conv', 'bias'), ('up_blocks_0', 'upsamplers_0', 'conv', 'kernel'), ('up_blocks_1', 'attentions_0', 'norm', 'bias'), ('up_blocks_1', 'attentions_0', 'norm', 'scale'), ('up_blocks_1', 'attentions_0', 'proj_in', 'bias'), ('up_blocks_1', 
'attentions_0', 'proj_in', 'kernel'), ('up_blocks_1', 'attentions_0', 'proj_out', 'bias'), ('up_blocks_1', 'attentions_0', 'proj_out', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_1', 'attentions_1', 'norm', 'bias'), ('up_blocks_1', 'attentions_1', 'norm', 'scale'), ('up_blocks_1', 'attentions_1', 'proj_in', 'bias'), ('up_blocks_1', 'attentions_1', 'proj_in', 'kernel'), 
('up_blocks_1', 'attentions_1', 'proj_out', 'bias'), ('up_blocks_1', 'attentions_1', 'proj_out', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_1', 'attentions_2', 'norm', 'bias'), ('up_blocks_1', 'attentions_2', 'norm', 'scale'), ('up_blocks_1', 'attentions_2', 'proj_in', 'bias'), ('up_blocks_1', 'attentions_2', 'proj_in', 'kernel'), ('up_blocks_1', 'attentions_2', 'proj_out', 
'bias'), ..., ('up_blocks_3', 'resnets_2', 'time_emb_proj', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~ModelMixin.to_fp32`] for further information on how to do this.

9.3 Inference

Since TPUs usually have 8 devices working in parallel, we’ll replicate our prompt as many times as there are devices. Then we’ll perform inference on the 8 devices at once, each responsible for generating one image. As a result, we’ll get 8 images in the same amount of time it takes one chip to generate a single one.

After replicating the prompt, we obtain the tokenized text ids by invoking the prepare_inputs function of the pipeline. The length of the tokenized text is set to 77 tokens, as required by the configuration of the underlying CLIP Text model.

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
(8, 77)
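Every prompt, long or short, ends up as a row of exactly 77 ids because the text is padded or truncated to that fixed length, which also keeps input shapes constant across calls. The toy sketch below illustrates the idea with NumPy only. `pad_or_truncate`, `toy_prepare_inputs`, the padding id and the token ids are all hypothetical and do not reproduce the real CLIP tokenizer.

```python
import numpy as np

MAX_LEN = 77  # fixed sequence length used by the CLIP text model
PAD_ID = 0    # hypothetical padding id, for illustration only

def pad_or_truncate(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Force a sequence of token ids to exactly max_len entries."""
    ids = list(token_ids)[:max_len]               # truncate long inputs
    return ids + [pad_id] * (max_len - len(ids))  # pad short inputs

def toy_prepare_inputs(tokenized_prompts):
    """Stack fixed-length rows into a (batch, max_len) int32 array."""
    return np.array([pad_or_truncate(t) for t in tokenized_prompts], dtype=np.int32)

# Made-up token ids for one prompt, replicated once per device (8 here).
num_devices = 8
prompt_tokens = [1, 2, 3, 4, 5]
batch = toy_prepare_inputs([prompt_tokens] * num_devices)
print(batch.shape)  # (8, 77)
```

Because every row has the same length, the compiled model always sees the same input shape, whatever the prompt.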

9.3.1 Replication and parallelization

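Model parameters and inputs have to be distributed across the 8 parallel devices we have, but in different ways. The parameters dictionary is replicated using flax.jax_utils.replicate, which traverses the dictionary and changes the shape of the weights so that each one is repeated 8 times, one copy per device. Inputs, instead, are split across the devices using shard, which reshapes the leading batch dimension into (number of devices, batch size per device).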

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
(8, 1, 77)

That shape means that each of the 8 devices will receive as input a jnp array with shape (1, 77), so 1 is the batch size per device. On TPUs with sufficient memory, it could be larger than 1 if we wanted to generate multiple images per chip at once.
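The difference between the two operations is mostly about shapes, so it can be mimicked with plain NumPy. The sketch below is an approximation under that assumption. `toy_shard` reshapes the leading axis into (devices, per-device batch), which matches the (8, 77) to (8, 1, 77) change shown above. `toy_replicate` just stacks copies. Both helpers are hypothetical, and the real flax.jax_utils.replicate also places each copy on its own device, which this sketch ignores.

```python
import numpy as np

def toy_shard(x, num_devices):
    """Split the leading batch axis into (num_devices, per_device_batch, ...)."""
    if x.shape[0] % num_devices != 0:
        raise ValueError("batch size must divide evenly across devices")
    return x.reshape((num_devices, -1) + x.shape[1:])

def toy_replicate(x, num_devices):
    """Stack num_devices identical copies along a new leading axis."""
    return np.stack([x] * num_devices)

prompt_ids = np.zeros((8, 77), dtype=np.int32)
sharded = toy_shard(prompt_ids, 8)            # (8, 1, 77): one prompt per device
sharded_2 = toy_shard(np.zeros((16, 77)), 8)  # (8, 2, 77): two prompts per device

weight = np.ones((320, 4))                    # hypothetical parameter shape
replicated = toy_replicate(weight, 8)         # (8, 320, 4): full copy per device
print(sharded.shape, sharded_2.shape, replicated.shape)
```

Sharding never duplicates data, it only regroups it, while replication multiplies memory use by the number of devices. That is why it is used for the weights, which every device needs in full.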

We are almost ready to generate images! We just need to create a random number generator key to pass to the generation function. This is the standard procedure in JAX and Flax, which are very explicit and opinionated about random numbers: all functions that deal with random numbers are expected to receive a key. This ensures reproducibility, even when we are training across multiple distributed devices.

The helper function below uses a seed to initialize a random number generator. As long as we use the same seed, we’ll get the exact same results. Feel free to use different seeds when exploring results later in the notebook.

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

We create a key and then “split” it into 8 subkeys so that each device receives a different one. Therefore, each device will create a different image, and the full process is reproducible.

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
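The reproducibility claim can be checked directly. This sketch redefines create_key from above so it runs on its own, and it assumes 8 devices to mirror the TPU setup. Splitting keys is a pure computation that does not depend on the hardware, so the block also runs on a CPU.

```python
import jax
import jax.numpy as jnp

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

num_devices = 8  # assumed device count, as on a TPU v2-8
keys = jax.random.split(create_key(0), num_devices)

# Each subkey produces different random numbers...
samples = [jax.random.normal(k, (3,)) for k in keys]

# ...and splitting the same seed again reproduces them exactly.
keys_again = jax.random.split(create_key(0), num_devices)
samples_again = [jax.random.normal(k, (3,)) for k in keys_again]

print(bool(jnp.array_equal(samples[0], samples_again[0])))  # True
```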

JAX code can be compiled to an efficient representation that runs very fast. However, we need to ensure that all inputs have the same shape in subsequent calls; otherwise, JAX has to recompile the code and we can’t take advantage of the optimized speed.

The Flax pipeline can compile the code for us if we pass jit=True as an argument. It will also ensure that the model runs in parallel on the 8 available devices.

The first time we run the following cell it will take a long time to compile, but subsequent calls (even with different inputs, as long as they have the same shapes) will be much faster. For example, when I tested on a TPU v2-8 it took more than a minute to compile, and then about 7s for each subsequent inference run.

%%time

images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s
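The recompilation rule can be observed on a toy function. The function `scale` and its counter below are illustrative, not part of diffusers. The sketch uses a Python side effect, which runs only while jax.jit traces a function. A call with a new input shape triggers a new trace, while a call with the same shape and dtype reuses the compiled code.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scale(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing/compilation
    return x * 2.0

scale(jnp.ones(4))   # first call with shape (4,): traced and compiled
scale(jnp.zeros(4))  # same shape and dtype: cached compiled code is reused
print(trace_count)   # 1
scale(jnp.ones(5))   # new shape (5,): traced and compiled again
print(trace_count)   # 2
```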

The returned array has shape (8, 1, 512, 512, 3). We reshape it to get rid of the second dimension and obtain 8 images of 512 × 512 × 3 and then convert them to PIL.

images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
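The reshape only merges the device axis with the per-device batch axis. The sketch below reproduces it on small random arrays standing in for the pipeline output, and the 4 × 4 image size is chosen only to keep it light. It then converts the floats in [0, 1] to uint8. To our understanding, that conversion is roughly what numpy_to_pil does before building PIL images, but treat that as an assumption.

```python
import numpy as np

# Stand-in for the pipeline output: (devices, per-device batch, H, W, C) in [0, 1].
images = np.random.default_rng(0).random((8, 1, 4, 4, 3), dtype=np.float32)

# Merge the first two axes: (8, 1, H, W, C) -> (8, H, W, C).
flat = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])

# Scale to 0..255 and cast, the usual step before PIL's Image.fromarray.
as_uint8 = (flat * 255).round().astype("uint8")
print(flat.shape, as_uint8.dtype)  # (8, 4, 4, 3) uint8
```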

9.3.2 Visualization

Let’s create a helper function to display images in a grid.

def image_grid(imgs, rows, cols):
    # Paste same-sized PIL images row by row into a single rows x cols canvas.
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
image_grid(images, 2, 4)
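The box argument is the top-left corner of each tile, computed row by row from the image index. This small pure-Python sketch lists the offsets image_grid uses for 8 images of 512 × 512 in a 2 × 4 grid. The helper name grid_boxes is hypothetical.

```python
def grid_boxes(num_images, cols, w, h):
    """Top-left (x, y) paste offset of each image, filled row by row."""
    return [((i % cols) * w, (i // cols) * h) for i in range(num_images)]

boxes = grid_boxes(8, cols=4, w=512, h=512)
print(boxes[:5])  # [(0, 0), (512, 0), (1024, 0), (1536, 0), (0, 512)]
```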

9.4 Using different prompts

We don’t have to replicate the same prompt in all the devices. We can do whatever we want: generate 2 prompts 4 times each, or even generate 8 different prompts at once. Let’s do that!

This time we pass a list of 8 different prompts, one per device, and then prepare, shard and run them exactly as before:

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background"
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

image_grid(images, 2, 4)
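The other option mentioned above, 2 prompts repeated 4 times each, only requires that the list ends up with one prompt per device. Here is a minimal sketch assuming 8 devices. `repeat_prompts` is a hypothetical helper, and its output would then go through pipeline.prepare_inputs and shard just like the prompts list above.

```python
num_devices = 8  # assumed: jax.device_count() on a TPU v2-8

def repeat_prompts(prompts, num_devices):
    """Repeat each prompt equally so there is exactly one entry per device."""
    if num_devices % len(prompts) != 0:
        raise ValueError("num_devices must be a multiple of the number of prompts")
    copies = num_devices // len(prompts)
    return [p for p in prompts for _ in range(copies)]

batch = repeat_prompts(
    ["Labrador in the style of Hokusai", "Armchair in the shape of an avocado"],
    num_devices,
)
print(len(batch))  # 8
```

Keeping copies of the same prompt adjacent means that, with image_grid(images, 2, 4), each row of the grid shows a single prompt rendered with four different random keys.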

9.5 How does parallelization work?

We said before that the diffusers Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We’ll now briefly look inside that process to show how it works.

JAX parallelization can be done in multiple ways. The easiest one revolves around using the jax.pmap function to achieve single-program, multiple-data (SPMD) parallelization. It means we’ll run several copies of the same code, each on different data inputs. More sophisticated approaches are possible; if you are interested in this topic, we invite you to go over the JAX documentation and the pjit pages.

jax.pmap does two things for us:

- Compiles (or jits) the code, as if we had invoked jax.jit(). This does not happen when we call pmap, but the first time the pmapped function is invoked.
- Ensures the compiled code runs in parallel on all the available devices.

To show how it works, we pmap the _generate method of the pipeline, which is the private method that generates images. Please note that this method may be renamed or removed in future releases of diffusers.

p_generate = pmap(pipeline._generate)

After we use pmap, the prepared function p_generate will conceptually do the following:

- Invoke a copy of the underlying function pipeline._generate on each device.
- Send each device a different portion of the input arguments. That’s what sharding is used for. In our case, prompt_ids has shape (8, 1, 77). This array will be split in 8, and each copy of _generate will receive an input with shape (1, 77).

We can code _generate completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (1 in this example) and the dimensions that make sense for our code, and don’t have to change anything to make it work in parallel.
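The same per-device style works for any function. In this sketch, `per_device_fn` is written as if it only ever sees one device’s slice, and jax.pmap maps it over the leading axis. That axis must not be larger than the number of available devices. The sketch uses jax.local_device_count(), so it also runs on a single-device CPU machine, where the leading axis is simply 1.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # 8 on a TPU v2-8, usually 1 on a CPU-only machine

def per_device_fn(x):
    # Written for one device's slice of shape (per_device_batch, 3).
    return x.sum(axis=-1)

p_fn = jax.pmap(per_device_fn)  # compilation happens on the first call, not here

# Leading axis = devices, then per-device batch (2) and features (3).
x = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape(n, 2, 3)
out = p_fn(x)
print(out.shape)  # (n, 2)
```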

As with the pipeline call, the first time we run the following cell it will take a while to compile, but subsequent runs will be much faster.

%%time

images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
(8, 1, 512, 512, 3)

We use block_until_ready() to measure inference time correctly, because JAX uses asynchronous dispatch and returns control to Python as soon as it can. You don’t need to use it in your own code; blocking happens automatically when you use the result of a computation that has not yet been materialized.
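Here is a small sketch of asynchronous dispatch, assuming a default JAX installation. The call to jnp.dot typically returns before the computation has finished, and the actual wait happens in block_until_ready(). The printed timings depend on the hardware and are only meant for comparison.

```python
import time
import jax.numpy as jnp

x = jnp.ones((1000, 1000), dtype=jnp.float32)

start = time.perf_counter()
y = jnp.dot(x, x)       # enqueues the work and usually returns right away
dispatched = time.perf_counter() - start
y.block_until_ready()   # waits until the result has actually been computed
total = time.perf_counter() - start

print(f"dispatch: {dispatched:.4f}s, finished: {total:.4f}s")
```

Timing without block_until_ready() would mostly measure dispatch, not the computation itself.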