lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

Inference with DeepSpeed #295

Open afiaka87 opened 3 years ago

afiaka87 commented 3 years ago

Trying to run generate.py on a DeepSpeed checkpoint currently breaks. Using inference with DeepSpeed should be relatively simple, I think, but I couldn't quite figure it out and realized most of the code I was writing actually belonged in the DeepSpeedBackend code, which I hadn't yet grokked. Anyway, so I don't forget, here is some very broken code I had written before giving up last night:

Edit: pretend I never wrote this. 
afiaka87 commented 3 years ago

Looking at train_dalle.py provides some insight from @janEbert's prior grokking of DeepSpeed. The first mistake I'm making here is loading the checkpoint like this:

dalle.load_state_dict(weights)

which is apparently a no-no for DeepSpeed's engine.
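
For reference, here is a minimal sketch of the engine-based route instead (assuming the checkpoint was written with engine.save_checkpoint; the checkpoint directory and tag below are hypothetical, and dalle / deepspeed_config are the same objects used at training time):

import deepspeed

# build the engine first, then let it restore its own checkpoint
engine, _, _, _ = deepspeed.initialize(
    model=dalle,
    model_parameters=dalle.parameters(),
    config_params=deepspeed_config,
)
load_path, client_state = engine.load_checkpoint('./checkpoints', tag=None)  # tag=None -> latest checkpoint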

afiaka87 commented 3 years ago

Okay - I did things the way they're meant to be done (I believe) @rom1504 @janEbert @mehdidc


if args.fp16:
    engine = deepspeed.init_inference(dalle, dtype=torch.half)
else:
    engine = deepspeed.init_inference(dalle)
# training

for epoch in range(EPOCHS):
    if data_sampler:
        data_sampler.set_epoch(epoch)
    for i, (text, images) in enumerate(distr_dl):
        if args.fp16:
            images = images.half()
        text, images = map(lambda t: t.cuda(), (text, images))
        loss = engine(text, images, return_loss=True)

        # update everything
        # ...

        if i % 100 == 0:
            if distr_backend.is_root_worker():
                sample_text = text[:1]
                token_list = sample_text.masked_select(sample_text != 0).tolist()
                decoded_text = tokenizer.decode(token_list)

                image = dalle.generate_images(text[:1], filter_thres=0.9)  # topk sampling at 0.9
                log = {
                    **log,
                }
                if not avoid_model_calls:
                    log['image'] = wandb.Image(image, caption=decoded_text)

And this properly runs backpropagation via some automated strategy parameter I haven't understood yet. It's all undocumented, so I'm just reading their code at this point. This may be an instance where their 'auto' policy is inserting values in the range [-1, 1] where the VQGAN expects values in the range [0, 1]? Bit out of my depth on this one.
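
For concreteness, that range theory would amount to a remap like the one below (a hypothetical sanity check only, not a confirmed fix; the traceback that follows actually fails on a tensor-dimension mismatch rather than a value range):

# if the pipeline hands the VQGAN values in [-1, 1] but it expects [0, 1]:
images_unit = (images + 1) / 2       # [-1, 1] -> [0, 1]
images_unit = images_unit.clamp(0, 1)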

Traceback (most recent call last):
  File "DALLE-pytorch/train_dalle.py", line 456, in <module>
    loss = engine(text, images, return_loss=True)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/deepspeed/inference/engine.py", line 222, in forward
    return self.module(*inputs, **kwargs)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/evo_internal_1TB/DALLE-pytorch/dalle_pytorch/dalle_pytorch.py", line 486, in forward
    image = self.vae.get_codebook_indices(image)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/evo_internal_1TB/DALLE-pytorch/dalle_pytorch/vae.py", line 173, in get_codebook_indices
    _, _, [_, _, indices] = self.model.encode(img)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/taming_transformers-0.0.1-py3.7.egg/taming/models/vqgan.py", line 54, in encode
    quant, emb_loss, info = self.quantize(h)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/taming_transformers-0.0.1-py3.7.egg/taming/modules/vqvae/quantize.py", line 42, in forward
    torch.sum(self.embedding.weight**2, dim=1) - 2 * \
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
afiaka87 commented 3 years ago

As always, apologies to Jan, who I'm sure has already explained this issue ;) I'll admit to some laziness with regard to doing the due diligence on all this 🤷

richcmwang commented 3 years ago

Generating seems to be tricky because DeepSpeed, DataParallel, etc. only work through an nn.Module's forward. But the following code works for me to balance the GPUs (model trained with stage 1) through distr_dl:

(distr_dalle, _, distr_dl, _) = distr_backend.distribute(
    args=args,
    model=dalle,
    optimizer=None,
    model_parameters=None,
    training_data=ds if using_deepspeed else dl,
    lr_scheduler=None,
    config_params=deepspeed_config,
)

for i, text in enumerate(distr_dl):
    t = time.time()
    text = text.cuda()
    print(f"generating {i} batch ...batch size {text.shape[0]}")
    image = dalle.generate_images(
        text, filter_thres=0.9)  # topk sampling at 0.9
    sec_per_sample = (time.time() - t) / BATCH_SIZE
    print(i, f'second_per_sample - {sec_per_sample}')

Note that I use dalle because distr_dalle does not work. In my small test case (16 heads / depth 16), both GPUs' loads are exactly the same.
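
A possible follow-up (my sketch, not code from this thread): when each GPU generates different batches, saving with rank-aware filenames keeps workers from overwriting each other's outputs. Inside the generation loop above, something like:

import os
import torch.distributed as dist
from torchvision.utils import save_image

rank = dist.get_rank() if dist.is_initialized() else 0
os.makedirs('./outputs', exist_ok=True)   # hypothetical output directory
for j, img in enumerate(image):           # `image` and `i` come from the loop above
    save_image(img, os.path.join('./outputs', f'rank{rank}_batch{i}_sample{j}.png'), normalize=True)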

rom1504 commented 3 years ago

Yeah, I guess that's useful if you want to generate many samples. That doesn't help improve the speed of one batch, though.

richcmwang commented 3 years ago

It does not improve the speed of one batch per GPU, but with two (or more) GPUs it does improve the overall speed. In my test case, the running-time ratio for 1 GPU over 2 GPUs is 1.55.

afiaka87 commented 3 years ago

Thanks @richcmwang! I'll work on this later unless you wanna make the PR.

@rom1504 The DeepSpeed docs do indeed claim faster inference with the inference engine. Not sure how though.
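
For what it's worth, the inference engine is typically set up along these lines (a sketch; mp_size and the kernel-injection flag are my assumptions, not settings from this thread — the claimed speedup comes from fp16 plus fused kernels and optional model parallelism):

import torch
import deepspeed

inference_engine = deepspeed.init_inference(
    dalle,
    mp_size=1,                        # tensor/model-parallel degree across GPUs
    dtype=torch.half,                 # fp16 accounts for much of the claimed speedup
    replace_with_kernel_inject=True,  # swap supported layers for DeepSpeed's fused kernels
)
# inference_engine.module is the wrapped dalle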

janEbert commented 3 years ago

@richcmwang Exactly, they need the forward call, which I'm pretty sure is also the reason why FP16 generation fails. They recommended using a simple if-switch in the forward method, like a do_generations=True argument: if it's given, skip the normal forward calculations and just do generations and exit. I haven't found the time to try it yet, though.
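
A minimal sketch of that if-switch (hypothetical names; DALLE-pytorch's forward does not currently take such an argument). Routing generation through forward() is what lets wrappers like DeepSpeed or DataParallel, which only intercept forward, apply:

import torch.nn as nn

class GenerationSwitch(nn.Module):
    def __init__(self, dalle):
        super().__init__()
        self.dalle = dalle

    def forward(self, text, images=None, return_loss=False, do_generations=False):
        if do_generations:
            # bypass the loss path entirely and just sample
            return self.dalle.generate_images(text, filter_thres=0.9)
        return self.dalle(text, images, return_loss=return_loss)

# usage through the distributed wrapper, e.g.:
# generated = distr_dalle(text, do_generations=True)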

Aside from inference being parallelizable, I think the biggest benefit is being able to do inference with models that don't fit into memory.

richcmwang commented 3 years ago

@afiaka87 Please feel free to incorporate this. I tried inference but get either an incorrect key "checkpoint_path" or an unknown type "DeepSpeed" error message. I'm not sure the doc is accurate.

"checkpoint.json":
{
  "type": "DeepSpeed",
    "version": 0.3,
    "checkpoint_path": "path_to_checkpoints",
}