ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Tried to use SDXL-turbo #44

Open ageorgios opened 11 months ago

ageorgios commented 11 months ago

I made the following changes to the model configuration:

_DEFAULT_MODEL = "stabilityai/sdxl-turbo"
_MODELS = {
    # See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
    "stabilityai/sdxl-turbo": {
        "unet_config": "unet/config.json",
        "unet": "unet/diffusion_pytorch_model.safetensors",
        "text_encoder_config": "text_encoder/config.json",
        "text_encoder": "text_encoder/model.safetensors",
        "vae_config": "vae/config.json",
        "vae": "vae/diffusion_pytorch_model.safetensors",
        "diffusion_config": "scheduler/scheduler_config.json",
        "tokenizer_vocab": "tokenizer/vocab.json",
        "tokenizer_merges": "tokenizer/merges.txt",
    }
}
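Since the error (below) comes down to a dimension mismatch between the UNet and the text encoder, one quick check is to compare the UNet's `cross_attention_dim` with the `hidden_size` of the text encoder(s) the mapping points at. The sketch below is a hypothetical diagnostic, not part of mlx-examples. It assumes the repository files are already in a local directory laid out like the Hugging Face repo, and it assumes `cross_attention_dim` is a single integer (as in the standard SDXL configs). The values in the demo are the standard SDXL ones (UNet 2048, `text_encoder` 768, `text_encoder_2` 1280).

```python
import json
from pathlib import Path
from typing import Dict, List, Sequence, Tuple


def load_config(path: Path) -> Dict:
    """Read one config.json file."""
    with open(path) as f:
        return json.load(f)


def cross_attention_dims(unet_cfg: Dict, encoder_cfgs: List[Dict]) -> Tuple[int, int]:
    """Return (dim the UNet's cross-attention expects, dim the listed text encoders provide).

    Assumes cross_attention_dim is a single int and that multiple encoders
    are concatenated along the feature axis, so their hidden sizes add up.
    """
    expected = unet_cfg["cross_attention_dim"]
    provided = sum(cfg["hidden_size"] for cfg in encoder_cfgs)
    return expected, provided


def check_local_snapshot(model_dir: str, encoder_dirs: Sequence[str] = ("text_encoder",)) -> None:
    """Hypothetical helper: compare the configs found in a local copy of the repo."""
    root = Path(model_dir)
    unet_cfg = load_config(root / "unet" / "config.json")
    encoder_cfgs = [load_config(root / d / "config.json") for d in encoder_dirs]
    expected, provided = cross_attention_dims(unet_cfg, encoder_cfgs)
    status = "OK" if expected == provided else "MISMATCH"
    print(f"{status}: UNet expects {expected}, encoders provide {provided}")


# Demo with the standard SDXL values instead of files on disk.
unet = {"cross_attention_dim": 2048}
clip_l = {"hidden_size": 768}      # text_encoder
clip_bigg = {"hidden_size": 1280}  # text_encoder_2
print(cross_attention_dims(unet, [clip_l]))             # (2048, 768)
print(cross_attention_dims(unet, [clip_l, clip_bigg]))  # (2048, 2048)
```

You could call `check_local_snapshot` on a downloaded copy twice, once with the default and once with `("text_encoder", "text_encoder_2")`. That would show whether the second encoder accounts for the gap. The directory path is yours to supply.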

but I get the following error:

% python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 1
  0%|                                                                                                                                           | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "txt2image.py", line 36, in <module>
    for x_t in tqdm(latents, total=args.steps):
  File "/Users/ageorgios/.pyenv/versions/3.8.18/lib/python3.8/site-packages/tqdm/std.py", line 1182, in __iter__
    for obj in iterable:
  File "/Users/ageorgios/Models/mlx-examples/stable_diffusion_turbo/stable_diffusion/__init__.py", line 84, in generate_latents
    eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
  File "/Users/ageorgios/Models/mlx-examples/stable_diffusion_turbo/stable_diffusion/unet.py", line 395, in __call__
    x, res = block(
  File "/Users/ageorgios/Models/mlx-examples/stable_diffusion_turbo/stable_diffusion/unet.py", line 252, in __call__
    x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
  File "/Users/ageorgios/Models/mlx-examples/stable_diffusion_turbo/stable_diffusion/unet.py", line 117, in __call__
    x = block(x, encoder_x, attn_mask, encoder_attn_mask)
  File "/Users/ageorgios/Models/mlx-examples/stable_diffusion_turbo/stable_diffusion/unet.py", line 70, in __call__
    y = self.attn2(y, memory, memory, memory_mask)
  File "/Users/ageorgios/.pyenv/versions/3.8.18/lib/python3.8/site-packages/mlx/nn/layers/transformer.py", line 68, in __call__
    keys = self.key_proj(keys)
  File "/Users/ageorgios/.pyenv/versions/3.8.18/lib/python3.8/site-packages/mlx/nn/layers/linear.py", line 33, in __call__
    x = x @ self.weight.T
ValueError: [matmul] Last dimension of first input with shape (2,13,768) must match second to last dimension of second input with shape (2048,320).
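How to read this message: the failing line is `x @ self.weight.T`, so the second operand `(2048, 320)` is the *transposed* weight. That means `key_proj` maps 2048-dim keys to 320 dims, but the keys arriving from the text encoder are 768-dim. Below is a small pure-Python sketch of that shape rule. The helper name is hypothetical and is not part of MLX.

```python
from typing import Sequence, Tuple


def linear_output_shape(x_shape: Sequence[int], weight_t_shape: Sequence[int]) -> Tuple[int, ...]:
    """Shape of `x @ W`, where W is the transposed Linear weight, i.e. (in_dims, out_dims)."""
    in_dims, out_dims = weight_t_shape
    if x_shape[-1] != in_dims:
        raise ValueError(
            f"layer expects last dim {in_dims}, got {x_shape[-1]} (input shape {tuple(x_shape)})"
        )
    # Leading (batch, sequence) dims pass through; only the feature dim changes.
    return tuple(x_shape[:-1]) + (out_dims,)


# Shapes copied from the error message above.
try:
    linear_output_shape((2, 13, 768), (2048, 320))
except ValueError as err:
    print(err)  # layer expects last dim 2048, got 768 (input shape (2, 13, 768))

print(linear_output_shape((2, 13, 2048), (2048, 320)))  # (2, 13, 320)
```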

angeloskath commented 11 months ago

I have to admit that our parsing of the configs and instantiation of the corresponding models is far from robust, so you probably need to investigate what is not loaded properly or how the model is instantiated incorrectly.

For what it's worth (it may be obvious), it seems the key_proj expects the keys to be of size 2048, but they are of size 768. So maybe the layer just before it was loaded incorrectly?
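One likely source of the 2048 vs 768 gap, stated here as an assumption about the standard SDXL architecture rather than about this repository's code: SDXL uses two text encoders, CLIP ViT-L (hidden size 768) and OpenCLIP ViT-bigG (hidden size 1280). Their per-token hidden states are concatenated into 768 + 1280 = 2048 features for the UNet's cross-attention. The `_MODELS` entry above only references `text_encoder`, so only 768-dim states would reach `key_proj`. The NumPy sketch below uses zero arrays as placeholders for the two encoder outputs and a `(320, 2048)` weight implied by the error. It only illustrates the shapes. Real SDXL pipelines also take the penultimate hidden layer and add pooled-embedding and time-id conditioning, none of which is modeled here.

```python
import numpy as np

# Standard SDXL text-encoder hidden sizes (assumption: CLIP ViT-L + OpenCLIP ViT-bigG).
CLIP_L_DIM = 768
CLIP_BIGG_DIM = 1280


def sdxl_encoder_states(h1: np.ndarray, h2: np.ndarray) -> np.ndarray:
    """Concatenate per-token hidden states from both text encoders on the feature axis."""
    if h1.shape[:-1] != h2.shape[:-1]:
        raise ValueError(f"batch/sequence shapes differ: {h1.shape} vs {h2.shape}")
    return np.concatenate([h1, h2], axis=-1)


def key_proj(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Mimic the traceback's Linear call: weight is (out, in), output is x @ weight.T."""
    return x @ weight.T


batch, seq_len = 2, 13  # taken from the error message
h1 = np.zeros((batch, seq_len, CLIP_L_DIM), dtype=np.float32)     # placeholder: text_encoder output
h2 = np.zeros((batch, seq_len, CLIP_BIGG_DIM), dtype=np.float32)  # placeholder: text_encoder_2 output
w = np.zeros((320, 2048), dtype=np.float32)  # key_proj weight implied by the error

cond = sdxl_encoder_states(h1, h2)
print(cond.shape)               # (2, 13, 2048)
print(key_proj(cond, w).shape)  # (2, 13, 320)

try:
    key_proj(h1, w)  # first encoder only: the same mismatch as in the issue
except ValueError as err:
    print("mismatch:", err)
```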

LeaveNhA commented 11 months ago

@ageorgios, can you fork the repo and push your changes so we can collaborate on it? How much RAM do I need to do it, BTW?
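A rough lower bound for the RAM question is the size of the weights alone: parameter count times bytes per parameter. The sketch below uses the approximate counts reported in the SDXL paper (about 2.6B parameters for the UNet and about 817M for both text encoders combined). It assumes SDXL-turbo shares that architecture. It leaves out the VAE, activations, and MLX runtime overhead, so actual usage will be higher.

```python
# Approximate parameter counts reported in the SDXL paper (assumed to apply to SDXL-turbo).
UNET_PARAMS = 2.6e9
TEXT_ENCODER_PARAMS = 0.817e9  # both text encoders combined


def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Memory for the weights alone, in decimal gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


total = UNET_PARAMS + TEXT_ENCODER_PARAMS
print(f"float16: {weight_memory_gb(total, 2):.2f} GB")  # float16: 6.83 GB
print(f"float32: {weight_memory_gb(total, 4):.2f} GB")  # float32: 13.67 GB
```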

ArthurMynl commented 11 months ago

Hey @ageorgios, I would be very interested in an SDXL version optimized for Apple Silicon using MLX. If you manage to make this work, can you ping me? 😄

USMCM1A1 commented 8 months ago

@ArthurMynl @ageorgios I would also be super interested in an SDXL (rather than Turbo) version.