ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Stable diffusion txt2image.py with dtype=float16 generates PNGs with all zero values. #404

Open · appthumb opened this issue 7 months ago

appthumb commented 7 months ago

I'm trying to run the Stable Diffusion txt2image.py with the float16 dtype on an M1 iMac with 8 GB of RAM, since the float32 dtype requires more than 8 GB of memory. I modified this line of code https://github.com/ml-explore/mlx-examples/blob/e9b32747b424468eabb5a7f0609f275637e1a0c3/stable_diffusion/txt2image.py#L26 to the following:

sd = StableDiffusion(float16=True)

I ran it with the following command:

python3 txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images=1

The execution finishes and produces a 512x512 PNG file 'out.png'.

However, the generated PNG has no content and is fully black. Inspecting the file shows that it contains only zeros. Does the Stable Diffusion model work with the float16 dtype? Thanks!
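
For reference, a quick way to confirm that the output really is all zeros (this snippet is mine, not part of the original report, and assumes Pillow and NumPy are installed):

import numpy as np
from PIL import Image

# Load the generated image and look at its pixel range; a fully black,
# all-zero PNG prints min == max == 0.
img = np.array(Image.open("out.png"))
print(img.shape, img.min(), img.max())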

angeloskath commented 7 months ago

Hm, it does indeed seem to be the case. This needs some investigation, so unfortunately it won't be a quick fix.

appthumb commented 7 months ago

Thanks for looking into this!

The zeros in the result are due to nan values. I traced where the nans first show up, and so far there are two sources:

  1. https://github.com/ml-explore/mlx/blob/95b5fb8245c776d63342e96f387fa01eb84b4492/python/mlx/nn/layers/transformer.py#L107: the hard-coded -1e9 is outside the float16 range, which causes the returned mask to contain nan right away. The comment above that line actually points out exactly this issue.

  2. The second source is the GroupNorm layer in the UnetBlock2D. In my test, after patching the first issue, the GroupNorm layer linked below produces the first nan at the beginning of the third up-block of the UNet:

https://github.com/ml-explore/mlx-examples/blob/e9b32747b424468eabb5a7f0609f275637e1a0c3/stable_diffusion/stable_diffusion/unet.py#L138

This is likely due to the calculation of mx.var, whose result can be too large for float16. One way to improve the situation would be to compute mx.var internally in float32 inside GroupNorm (and potentially the other normalization layers). A small sketch below illustrates both failure modes.
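
For illustration, here is a small sketch (mine, not from the thread) of both failure modes at the float16 level:

import mlx.core as mx

# -1e9 is far outside the float16 range (the largest finite value is 65504),
# so it overflows to -inf once cast down.
big_neg = mx.array(-1e9).astype(mx.float16)
print(big_neg)  # -inf

# A 0/1 additive mask multiplied by -inf turns the zero entries into nan
# (0 * inf is nan under IEEE 754), which matches the nan mask described above.
mask = mx.array([0.0, 1.0], dtype=mx.float16) * big_neg
print(mask)  # [nan, -inf]

# The GroupNorm problem is of the same nature: squared deviations of
# moderately large activations already exceed the float16 range, so a
# variance computed in float16 can overflow as well.
x = mx.array([300.0, -300.0], dtype=mx.float16)
print(x * x)  # [inf, inf], since 90000 > 65504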

zetyquickly commented 5 months ago

float16 seems to be broken more generally (https://github.com/ml-explore/mlx/issues/483): when certain values are reshaped they can come out different, as described in that issue.

Did you try bfloat16?

zetyquickly commented 5 months ago

Regarding GroupNorm: I've tried to match the outputs of the PyTorch code and the MLX code, and even with float32 they give different results.

Sample:

import mlx.core as mx
import mlx.nn as nn
import torch
import numpy as np

N, C, X = 1, 4, 1
a_np = np.random.normal(size=(N, X, C))  # MLX layers expect channels last

a_mx = mx.array(a_np, dtype=mx.float32)
norm_num_groups = 2
# Note: pytorch_compatible must be passed by keyword; the third positional
# argument of mlx.nn.GroupNorm is eps.
gnorm = nn.GroupNorm(norm_num_groups, C, pytorch_compatible=True)
b_mx = gnorm(a_mx)

a_torch = torch.tensor(a_np, dtype=torch.float32)
gnorm_torch = torch.nn.GroupNorm(norm_num_groups, C)  # expects channels first
b_torch = gnorm_torch(a_torch.reshape(N, C, X))

# Reshape back to (N, X, C) so the outputs are compared elementwise instead
# of being broadcast against each other.
print(np.allclose(b_torch.detach().numpy().reshape(N, X, C), np.array(b_mx, copy=False)))
# Output: False

appthumb commented 5 months ago

I didn't try bfloat16, but I think it would be even worse because it uses even fewer bits for the significand.

Regarding matching PyTorch's GroupNorm in float32: the results are actually close, so with a larger tolerance the comparison should pass. I got the float16 GroupNorm to be "close" to PyTorch with a fairly large tolerance (5e-3), which is probably okay for inference. For this to work, GroupNorm has to use float32 internally, otherwise it overflows when calculating mx.var (and probably mx.mean); a rough sketch of that idea follows.
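
For illustration, here is a minimal sketch (my own code, not the repo's) of computing a group norm with float32 statistics and casting the result back to the input dtype; the affine scale/shift is omitted and the input is assumed to be channels-last with C divisible by num_groups:

import mlx.core as mx

def group_norm_fp32(x, num_groups, eps=1e-5):
    B, C = x.shape[0], x.shape[-1]
    x32 = x.astype(mx.float32)                           # do the statistics in float32
    g = x32.reshape(B, -1, num_groups, C // num_groups)  # (B, spatial, groups, C/groups)
    mean = mx.mean(g, axis=(1, 3), keepdims=True)
    var = mx.var(g, axis=(1, 3), keepdims=True)          # safe from float16 overflow here
    g = (g - mean) * mx.rsqrt(var + eps)
    return g.reshape(*x.shape).astype(x.dtype)           # cast back to the input dtype

# Large float16 activations whose squared deviations would overflow in float16:
x = (mx.random.normal((1, 8, 8, 32)) * 300).astype(mx.float16)
out = group_norm_fp32(x, num_groups=8)
print(out.dtype, mx.all(mx.isfinite(out)))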

This is the approach taken by the updated example code, which seems to work now. See https://github.com/ml-explore/mlx-examples/commit/3a9e6c3f701baa2ab2745dd7641cec003024e7ed#diff-8e457d43b158ae71a0cbd888339296a6682608cc38b89581cf24e3464f3464ed, especially lines like https://github.com/ml-explore/mlx-examples/blob/c68aa3c7c381b39a563a9102c5925f9d3b1523a8/stable_diffusion/stable_diffusion/unet.py#L113, which cast up to float32 when applying GroupNorm.
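
The pattern is roughly to cast up around the normalization call, along these lines (an illustrative sketch with hypothetical names, not the exact diff):

import mlx.core as mx
import mlx.nn as nn

def norm_in_fp32(norm_layer: nn.Module, x: mx.array) -> mx.array:
    # Run the normalization in float32 and cast the result back to the
    # working dtype (e.g. float16) afterwards.
    return norm_layer(x.astype(mx.float32)).astype(x.dtype)

# e.g. inside a block's __call__ (attribute name hypothetical):
#   y = norm_in_fp32(self.norm1, x)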