crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

sample.py: Error while deserializing header: HeaderTooLarge #99

Closed: drscotthawley closed this issue 5 months ago

drscotthawley commented 5 months ago

Hi Kat et al! Thanks for making this great repo available. I've trained a few models so far and it's working great. But now I've tried to sample from a checkpoint and I'm getting an unexpected error.

I trained the MNIST demo using

/runs/shawley/k-diffusion/mnist/mnist_trans$  ~/diffusion/k-diffusion/train.py --config ~/diffusion/k-diffusion/configs/config_mnist_transformer.json

Over time, this wrote a bunch of checkpoint files:

/runs/shawley/k-diffusion/mnist/mnist_trans$ ls *.pth
model_00010000.pth  model_00040000.pth  model_00070000.pth  model_00100000.pth
model_00020000.pth  model_00050000.pth  model_00080000.pth  model_00110000.pth
model_00030000.pth  model_00060000.pth  model_00090000.pth  model_00120000.pth

When I try to run sample.py using the last checkpoint, I get this...

/runs/shawley/k-diffusion/mnist/mnist_trans$ ~/diffusion/k-diffusion/sample.py --batch-size 10 --checkpoint model_00120000.pth --config ~/diffusion/k-diffusion/configs/config_mnist_transformer.json -n 10 --prefix newsample_

Using device: cuda
Traceback (most recent call last):
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 75, in <module>
    main()
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 44, in main
    inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/safetensors/torch.py", line 308, in load_file
    with safe_open(filename, framework="pt", device=device) as f:
safetensors_rust.SafetensorError: Error while deserializing header: HeaderTooLarge

Is it absolutely necessary to create a "slim checkpoint" before sampling? I have not done that (yet).
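For context on the first error: as I understand the safetensors format, a file starts with an 8-byte little-endian unsigned integer giving the length of a JSON header, followed by that header and then the raw tensor data. A checkpoint written by torch.save (zip-based by default) starts with the zip signature PK\x03\x04 instead, so reading its first 8 bytes as a header length gives an enormous number, which is consistent with the HeaderTooLarge error above. The sketch below uses only the standard library to check which kind of file you have before picking a loader; sniff_checkpoint, declared_header_len and the 100,000,000-byte cap are illustrative assumptions, not part of the safetensors API.

```python
import json
import struct

# Assumed cap, for illustration only; safetensors enforces its own limit on
# the declared header size, and this value may not match it exactly.
MAX_HEADER_BYTES = 100_000_000
ZIP_MAGIC = b"PK\x03\x04"  # local-file-header signature at the start of a zip archive


def sniff_checkpoint(path):
    """Return 'zip' (torch.save-style archive), 'safetensors', or 'unknown'."""
    with open(path, "rb") as f:
        head = f.read(8)
        if head.startswith(ZIP_MAGIC):
            return "zip"
        if len(head) < 8:
            return "unknown"
        # safetensors layout: 8-byte little-endian unsigned header length,
        # then a JSON header of exactly that many bytes, then tensor data.
        (header_len,) = struct.unpack("<Q", head)
        if header_len > MAX_HEADER_BYTES:
            return "unknown"  # roughly the case safetensors reports as HeaderTooLarge
        try:
            header = json.loads(f.read(header_len))
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            return "unknown"
        return "safetensors" if isinstance(header, dict) else "unknown"


def declared_header_len(path):
    """Decode the first 8 bytes the way a safetensors reader would."""
    with open(path, "rb") as f:
        return struct.unpack("<Q", f.read(8))[0]
```

For a training checkpoint saved with torch.save's default settings, sniff_checkpoint would be expected to return "zip", which is the signal to use torch.load rather than safetensors.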

drscotthawley commented 5 months ago

OK, if I run convert_for_inference.py first, loading the checkpoint works, but then sample.py raises a different error:

/runs/shawley/k-diffusion/mnist/mnist_trans$ ~/diffusion/k-diffusion/convert_for_inference.py model_00120000.pth  --config ~/diffusion/k-diffusion/configs/config_mnist_transformer.json --output 120000_slim.pth
Loading training checkpoint model_00120000.pth...
Saving inference checkpoint to 120000_slim.pth...
/runs/shawley/k-diffusion/mnist/mnist_trans$ ~/diffusion/k-diffusion/sample.py --batch-size 10 --checkpoint 120000_slim.pth --config ~/diffusion/k-diffusion/configs/config_mnist_transformer.json -n 10 --prefix newsample_
Using device: cuda
Parameters: 9187360
Sampling...
  0%|                                                                                                   | 0/50 [00:00<?, ?it/s]
  0%|                                                                                                    | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 75, in <module>
    main()
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 69, in main
    run()
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 62, in run
    x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
  File "/home/shawley/github/k-diffusion/k_diffusion/evaluation.py", line 86, in compute_features
    samples = sample_fn(cur_batch_size)[:cur_batch_size]
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 60, in sample_fn
    x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process)
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/shawley/github/k-diffusion/k_diffusion/sampling.py", line 267, in sample_lms
    denoised = model(x, sigmas[i] * s_in, **extra_args)
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shawley/github/k-diffusion/k_diffusion/layers.py", line 90, in forward
    return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shawley/github/k-diffusion/k_diffusion/models/image_transformer_v2.py", line 711, in forward
    raise ValueError("class_cond must be specified if num_classes > 0")
ValueError: class_cond must be specified if num_classes > 0

Where does the class_cond go? (I don't see it in sample.py, and the config file already has "cond_dropout_rate": 0.1.)
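The traceback already shows how keyword arguments travel: sample_lms calls model(x, sigmas[i] * s_in, **extra_args), the wrapper in layers.py (line 90) forwards **kwargs to the inner model, and the transformer's forward raises if class_cond is missing. The toy, pure-Python sketch below mimics that chain to show why extra_args is the place to put class_cond; ToyInnerModel, ToyDenoiser and toy_sampler are hypothetical stand-ins, not k-diffusion classes.

```python
# Hypothetical stand-ins for the call chain in the traceback:
# sampler loop -> wrapper.forward -> inner transformer.forward


class ToyInnerModel:
    """Mimics the check that raised the ValueError (names are illustrative)."""

    def __init__(self, num_classes):
        self.num_classes = num_classes

    def __call__(self, x, sigma, class_cond=None):
        if self.num_classes > 0 and class_cond is None:
            raise ValueError("class_cond must be specified if num_classes > 0")
        return [xi * 0.5 for xi in x]  # placeholder "denoised" output


class ToyDenoiser:
    """Wraps the inner model and forwards any extra keyword arguments untouched."""

    def __init__(self, inner_model):
        self.inner_model = inner_model

    def __call__(self, x, sigma, **kwargs):
        return self.inner_model(x, sigma, **kwargs)


def toy_sampler(model, x, sigmas, extra_args=None):
    """Sampler-style loop: extra_args is unpacked into every model call."""
    extra_args = {} if extra_args is None else extra_args
    for sigma in sigmas:
        x = model(x, sigma, **extra_args)
    return x
```

Calling toy_sampler on a model with num_classes=10 and no extra_args reproduces the same ValueError; passing extra_args={"class_cond": 5} lets the call go through.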

drscotthawley commented 5 months ago

OK, so class_cond can be passed via extra_args in the call to sample_lms. For example, to generate a whole bunch of MNIST images of the digit "5", the following change to line 60 of sample.py works:

            x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args={'class_cond':torch.tensor([5]).to(device)})
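torch.tensor([5]) holds a single label even though the batch has 10 samples; it evidently worked here. A sketch that builds one label per sample makes the batch dimension explicit and also allows random digits. make_class_cond is a hypothetical helper, and it assumes class_cond takes integer class indices in [0, num_classes), which matches the num_classes check in the error above.

```python
import torch


def make_class_cond(batch_size, digit=None, num_classes=10, device="cpu"):
    """Return an int64 tensor of shape (batch_size,) with one class label per sample.

    digit=None draws a random label for each sample; otherwise every sample
    gets the same label.
    """
    if digit is None:
        # One independent random class index per sample (int64 by default).
        return torch.randint(0, num_classes, (batch_size,), device=device)
    if not 0 <= digit < num_classes:
        raise ValueError(f"digit must be in [0, {num_classes}), got {digit}")
    return torch.full((batch_size,), digit, dtype=torch.long, device=device)
```

Inside sample_fn it could be used as extra_args={'class_cond': make_class_cond(x.shape[0], digit=5, device=device)}, assuming x and device are the variables already used on that line.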

So, what I'm learning is that you must run convert_for_inference.py first, and you may need to modify sample.py to specify a class_cond.

Does that sound right? If so, I'll close this issue.

nom commented 5 months ago

You can also just do a regular torch.load instead of the safetensors version.

# in sample.py: load the training checkpoint directly instead of via safetensors
inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
ckpt = torch.load(args.checkpoint, map_location='cpu')
inner_model.load_state_dict(ckpt['model'])

This doesn't require converting the checkpoint first.
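A sketch that combines both approaches, so one sample.py works with either a full training checkpoint or a slim one. load_weights is a hypothetical helper; it assumes the training checkpoint was written by torch.save in its default zip format and keeps the weights under 'model', as in the snippet above, and it only imports safetensors when the file is not a zip archive. Recent PyTorch releases default torch.load to weights_only=True, so a checkpoint holding arbitrary Python objects may need that argument adjusted.

```python
import torch

ZIP_MAGIC = b"PK\x03\x04"  # torch.save's default (zip-based) format starts with this


def load_weights(path, key="model"):
    """Return a state dict from a full training checkpoint or a slim safetensors file.

    Assumes training checkpoints were written by torch.save in its default zip
    format and store the model weights under `key`.
    """
    with open(path, "rb") as f:
        is_torch_zip = f.read(4) == ZIP_MAGIC
    if is_torch_zip:
        ckpt = torch.load(path, map_location="cpu")
        return ckpt[key]
    # Otherwise treat the file as safetensors (e.g. output of convert_for_inference.py).
    from safetensors.torch import load_file

    return load_file(path)
```

With it, the loading line in sample.py would become inner_model.load_state_dict(load_weights(args.checkpoint)).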

drscotthawley commented 5 months ago

Thank you @nom! Closing.