Ok, if I run `convert_for_inference.py` first, that works, but then `sample.py` generates a different error:
```
/runs/shawley/k-diffusion/mnist/mnist_trans$ ~/diffusion/k-diffusion/convert_for_inference.py model_00120000.pth --config ~/diffusion/k-diffusion/configs/config_mnist_transformer.json --output 120000_slim.pth
Loading training checkpoint model_00120000.pth...
Saving inference checkpoint to 120000_slim.pth...
/runs/shawley/k-diffusion/mnist/mnist_trans$ ~/diffusion/k-diffusion/sample.py --batch-size 10 --checkpoint 120000_slim.pth --config ~/diffusion/k-diffusion/configs/config_mnist_transformer.json -n 10 --prefix newsample_
Using device: cuda
Parameters: 9187360
Sampling...
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 75, in <module>
    main()
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 69, in main
    run()
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 62, in run
    x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
  File "/home/shawley/github/k-diffusion/k_diffusion/evaluation.py", line 86, in compute_features
    samples = sample_fn(cur_batch_size)[:cur_batch_size]
  File "/home/shawley/diffusion/k-diffusion/sample.py", line 60, in sample_fn
    x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process)
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/shawley/github/k-diffusion/k_diffusion/sampling.py", line 267, in sample_lms
    denoised = model(x, sigmas[i] * s_in, **extra_args)
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shawley/github/k-diffusion/k_diffusion/layers.py", line 90, in forward
    return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/shawley/envs/hs/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/shawley/github/k-diffusion/k_diffusion/models/image_transformer_v2.py", line 711, in forward
    raise ValueError("class_cond must be specified if num_classes > 0")
ValueError: class_cond must be specified if num_classes > 0
```
Where does the `class_cond` go? (I don't see it in `sample.py`, and the config file already has `"cond_dropout_rate": 0.1`.)
Ok, so `class_cond` can be specified via `extra_args` in the call to `sample_lms`. So, e.g., to generate a whole bunch of MNIST digits of the number "5", the following modification of `sample.py` line 60 works:
```python
x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process,
                            extra_args={'class_cond': torch.tensor([5]).to(device)})
```
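If you'd rather spell out one label per sample (e.g. to mix digits within a batch), a shape-`[batch_size]` tensor should also work. This is just a sketch I haven't run, assuming `class_cond` accepts per-sample labels:

```python
# Sketch (untested): one class label per sample, assuming class_cond
# accepts a long tensor of shape [batch_size].
class_cond = torch.full((x.shape[0],), 5, dtype=torch.long, device=device)
x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process,
                            extra_args={'class_cond': class_cond})
```

Swapping in `torch.randint(0, 10, (x.shape[0],), device=device)` would give a random mix of digits.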
So, what I'm learning is that you must run `convert_for_inference.py` first, and you may need to modify `sample.py` to specify a `class_cond`.

Does that sound right? If so, I'll close.
You can also just do a regular `torch.load` instead of the safetensors version:

```python
import torch
import k_diffusion as K

# config and args here come from the surrounding sample.py setup
inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
ckpt = torch.load(args.checkpoint, map_location='cpu')  # full training checkpoint
inner_model.load_state_dict(ckpt['model'])
```

Doesn't require conversion.
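In case it's useful, here's a rough sketch (mine, untested) of wiring that loaded model into sampling, assuming the denoiser wrapper from the training script and the config keys (`sigma_min`, `sigma_max`, `input_channels`, `input_size`) used by the k-diffusion image configs:

```python
# Sketch only: wrap the raw model the way train.py does, then sample.
model = K.config.make_denoiser_wrapper(config)(inner_model)
size = [config['model']['input_channels'], *config['model']['input_size']]
sigma_min = config['model']['sigma_min']
sigma_max = config['model']['sigma_max']
sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, device=device)
x = torch.randn([10, *size], device=device) * sigma_max  # 10 pure-noise samples
# for class-conditional models, pass extra_args={'class_cond': ...} as above
x_0 = K.sampling.sample_lms(model, x, sigmas)
```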
Thank you @nom! Closing.
Hi Kat et al! Thanks for making this great repo available. I've trained a few models so far and it's working great. But now I've tried to sample from a checkpoint and I'm getting an unexpected error.
I trained the MNIST demo using
Which over time wrote a bunch of checkpoint files...
When I try to run `sample.py` using the last checkpoint, I get this...

Is it absolutely necessary to create a "slim checkpoint" before sampling? I have not done that (yet).