rosinality / vq-vae-2-pytorch

Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch

Stuck at epoch 1 iter 1 when train vqvae with multi-gpu #70

Open JamesZhutheThird opened 2 years ago

JamesZhutheThird commented 2 years ago

Single-GPU training works fine for me, and the output samples are satisfactory. However, when I set --n_gpu 2 or --n_gpu 4, training gets stuck at the very beginning (Epoch 1, Iter 1). That first iteration also takes much longer (34 seconds) than an iteration in single-GPU training (3 seconds per iter).

[image]

I would be grateful if someone could help me to see what might be wrong with this.

xuyanging commented 1 year ago

I have the same problem. Have you worked it out?

JamesZhutheThird commented 1 year ago

Well, I think I solved this by adding torch.cuda.empty_cache() at the end of each iteration.

ekyy2 commented 1 year ago

May I ask where you added torch.cuda.empty_cache()? I am experiencing the same problem. Thank you.

JamesZhutheThird commented 1 year ago

Original

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]
            # ...
            if i % 100 == 0:
                model.eval()
                # ...
                with torch.no_grad():
                    out, _ = model(sample)
                # ...
                model.train()

Fixed

        if dist.is_primary():
            with torch.no_grad():
                lr = optimizer.param_groups[0]["lr"]
                # ...
                if i % 100 == 0:
                    model.eval()
                    # ...
                    out, _ = model(sample)
                    # ...
                    torch.cuda.empty_cache()
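
The "Fixed" snippet above switches the model to eval mode for sampling but no longer switches it back to train mode in the same block. One way to make the mode switch harder to get wrong is a small context manager. This is only a minimal sketch: `sampling_mode` is a hypothetical helper name, and the tiny model below is a stand-in, not the VQ-VAE from this repo. It does not by itself address the multi-GPU hang.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def sampling_mode(model: nn.Module):
    """Run the body in eval mode with autograd off, then restore the previous mode."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield model
    finally:
        # Restore whatever mode the model was in, even if the body raised.
        model.train(was_training)


# Stand-in model and batch for demonstration only (hypothetical, not from the repo).
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
sample = torch.randn(2, 4)

model.train()
with sampling_mode(model):
    out = model(sample)  # no autograd graph is recorded here
```

After the `with` block the model is back in train mode, so the next training iteration does not depend on a separate `model.train()` call.
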

ekyy2 commented 1 year ago

@JamesZhutheThird Thank you so much for your quick reply. I believe the model still gets stuck when computing recon_loss.item() for part_mse_sum after I made the following changes:

    if dist.is_primary():
        with torch.no_grad():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description(
                (
                    f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                    f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                    f"lr: {lr:.5f}"
                )
            )

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                # with torch.no_grad():
                out, _ = model(sample)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                model.train()
                torch.cuda.empty_cache()

I am not sure whether I understood the changes to be made correctly, so your advice would be much appreciated. Thanks.

JamesZhutheThird commented 1 year ago

Perhaps you can try moving model.train() to the beginning of each iteration:

    for i, (img, label) in enumerate(loader):
        model.train()
        model.zero_grad()

I don't see anything else different from my code, but since I have changed other parts of the code and added extra functions, I'm not sure whether they are related to this bug. Anyway, please keep in touch with me about this. @ekyy2

ekyy2 commented 1 year ago

I tried the change you suggested, but it does not seem to work. No worries; as you said, it could be something else. One possibility is that I am using CUDA 11.6 with the PyTorch build for CUDA 11.3 (pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113). Another is the blow-up of the latent error mentioned here: https://github.com/rosinality/vq-vae-2-pytorch/issues/65#issue-870480215, although it is odd that it happens only with multiple GPUs. Anyway, let's keep in touch.
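
To compare environments when a CUDA/PyTorch mismatch is suspected, it can help to print what the installed PyTorch build actually reports. The following is a small sketch; `describe_torch_build` is a hypothetical helper name, and it reports only the CUDA version PyTorch was built against, not the version of the system toolkit or driver.

```python
import torch


def describe_torch_build() -> dict:
    """Collect basic facts about the installed PyTorch build."""
    return {
        "torch": torch.__version__,
        # CUDA version the wheel was built for; None on CPU-only builds.
        "built_for_cuda": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
        "gpu_count": torch.cuda.device_count(),
    }


info = describe_torch_build()
for key, value in info.items():
    print(f"{key}: {value}")
```

Running this on each machine gives a like-for-like comparison that can be pasted into the thread.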

JamesZhutheThird commented 1 year ago

Environment 1:
python         3.9.13
torch          1.11.0+cu113
torchvision    0.12.0+cu113

Environment 2:
python         3.9.13
torch          1.12.0+cu113
torchaudio     0.12.0+cu113
torchvision    0.13.0+cu113

I've tested two environments with different PyTorch versions, both with CUDA 11.2. I'm fairly sure the exact versions don't matter much.

By the way, I checked the files in ./distributed and they are identical to those in this repo.

subminu commented 12 months ago

For anyone who has this problem, here is my solution: set find_unused_parameters to True in DDP. The problem is triggered by the unused quantized vector. (Each DDP process waits for the other processes until the gradients of all learnable model parameters have been used.)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True,  # here
        )

@ekyy2, I hope this solution works for you.
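
To check whether a model has parameters that receive no gradient (the situation find_unused_parameters=True is meant to handle), a single-process test on one batch may be enough. The sketch below is an illustration under assumptions: `params_without_grad` and the toy `TwoHeads` module are hypothetical and stand in for the real VQ-VAE and its loss.

```python
import torch
from torch import nn


def params_without_grad(model: nn.Module, loss: torch.Tensor) -> list:
    """Backpropagate `loss` and return names of trainable parameters left without a gradient."""
    model.zero_grad(set_to_none=True)  # make sure stale gradients do not hide the problem
    loss.backward()
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]


class TwoHeads(nn.Module):
    """Toy module with one branch that never takes part in the forward pass."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # never called in forward

    def forward(self, x):
        return self.used(x)


toy = TwoHeads()
loss = toy(torch.randn(3, 4)).sum()
missing = params_without_grad(toy, loss)
print(missing)  # ['unused.weight', 'unused.bias']
```

If the list is empty for the real model and loss, unused parameters are probably not the cause of the hang. If it is not empty, find_unused_parameters=True is one option, at the cost of an extra traversal of the autograd graph on every iteration.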

ekyy2 commented 11 months ago

I can confirm that this works. Thank you so much! The issue can be closed.

peylnog commented 7 months ago

I still have the problem.