Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

manual_backward and .backward() have different behaviour. #18740

Open roedoejet opened 11 months ago

roedoejet commented 11 months ago

Bug description

I expected manual_backward and .backward() to perform backpropagation in the same way, but when I use self.manual_backward a number of parameters end up with no gradient (they are reported as unused). If I use .backward() directly, the problem doesn't occur.
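For context, manual_backward is meant to be used under manual optimization. Below is a minimal sketch of such a setup (placeholder model and loss, not my actual module, and assuming self.automatic_optimization = False is set in __init__), where I would expect the two calls to populate identical gradients:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class MinimalManualOptModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(80, 1)  # placeholder model
        self.automatic_optimization = False  # required for manual_backward

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt = self.optimizers()
        opt.zero_grad()
        loss = F.mse_loss(self.layer(x), y)
        # Expectation: these two lines should produce the same gradients.
        self.manual_backward(loss)
        # loss.backward()
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)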

What version are you seeing the problem on?

v2.0

How to reproduce the bug

def training_step(self, batch, batch_idx):
        x, y, _, y_mel = batch
        y = y.unsqueeze(1)
        # x.size() & y_mel.size() = [batch_size, n_mels=80, n_frames=32]
        # y.size() = [batch_size, segment_size=8192]
        optim_g, optim_d = self.optimizers()
        scheduler_g, scheduler_d = self.lr_schedulers()
        # generate waveform
        if self.config.model.istft_layer:
            mag, phase = self(x)
            generated_wav = self.inverse_spectral_transform(
                mag * torch.exp(phase * 1j)
            ).unsqueeze(-2)
        else:
            generated_wav = self(x)

        # create mel
        generated_mel_spec = dynamic_range_compression_torch(
            self.spectral_transform(generated_wav).squeeze(1)[:, :, 1:]
        )
        # train discriminators
        optim_d.zero_grad()
        # MPD
        y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, generated_wav.detach())
        if self.use_gradient_penalty:
            gp_f = self.compute_gradient_penalty(y.data, generated_wav.detach().data, self.mpd)
        else:
            gp_f = None
        loss_disc_f, _, _ = self.discriminator_loss(y_df_hat_r, y_df_hat_g, gp=gp_f)
        self.log("training/disc/mpd_loss", loss_disc_f, prog_bar=False)
        # MSD
        y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, generated_wav.detach())
        # gradient penalty for the MSD, mirroring the MPD branch above
        if self.use_gradient_penalty:
            gp_s = self.compute_gradient_penalty(
                y.data, generated_wav.detach().data, self.msd
            )
        else:
            gp_s = None
        loss_disc_s, _, _ = self.discriminator_loss(y_ds_hat_r, y_ds_hat_g, gp=gp_s)
        self.log("training/disc/msd_loss", loss_disc_s, prog_bar=False)
        # calculate loss
        disc_loss_total = loss_disc_s + loss_disc_f
        # manual optimization because PyTorch Lightning 2.0+ requires manual optimization when using multiple optimizers
        # this works
        disc_loss_total.backward()
        # this does not
        # self.manual_backward(disc_loss_total)
        optim_d.step()
        scheduler_d.step()
        # log discriminator loss
        self.log("training/disc/d_loss_total", disc_loss_total, prog_bar=False)

        # train generator
        optim_g.zero_grad()
        # calculate loss
        _, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, generated_wav)
        _, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, generated_wav)
        loss_fm_f = self.feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = self.feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, _ = self.generator_loss(
            y_df_hat_g, gp=self.use_gradient_penalty
        )
        loss_gen_s, _ = self.generator_loss(
            y_ds_hat_g, gp=self.use_gradient_penalty
        )
        self.log("training/gen/loss_fmap_f", loss_fm_f, prog_bar=False)
        self.log("training/gen/loss_fmap_s", loss_fm_s, prog_bar=False)
        self.log("training/gen/loss_gen_f", loss_gen_f, prog_bar=False)
        self.log("training/gen/loss_gen_s", loss_gen_s, prog_bar=False)
        loss_mel = F.l1_loss(y_mel, generated_mel_spec) * 45
        gen_loss_total = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
        # manual optimization because PyTorch Lightning 2.0+ requires manual optimization when using multiple optimizers
        gen_loss_total.backward()
        optim_g.step()
        scheduler_g.step()
        # log generator loss
        self.log("training/gen/gen_loss_total", gen_loss_total, prog_bar=True)
        self.log("training/gen/mel_spec_error", loss_mel / 45, prog_bar=False)

I caught this by adding an on_after_backward method. When I use self.manual_backward(disc_loss_total) or self.manual_backward(gen_loss_total), a number of parameters end up with p.grad == None, but when I use disc_loss_total.backward() everything works fine.
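The check was along these lines (a rough sketch of the hook; the exact reporting is illustrative):

def on_after_backward(self):
    # Report parameters that received no gradient after the backward pass.
    unused = [
        name
        for name, p in self.named_parameters()
        if p.requires_grad and p.grad is None
    ]
    if unused:
        print(f"{len(unused)} parameters have no gradient, e.g. {unused[:5]}")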

Error messages and logs

No response

Environment

Current environment

* CUDA:
  - GPU: Tesla V100-SXM2-16GB
  - available: True
  - version: 11.7
* Lightning:
  - lightning: 2.0.4
  - lightning-cloud: 0.5.39
  - lightning-utilities: 0.9.0
  - pytorch-lightning: 2.0.9.post0
  - torch: 2.0.1+cu117
  - torchaudio: 2.0.2+cu117
  - torchmetrics: 1.2.0
* Packages: absl-py: 2.0.0 - aiohttp: 3.8.5 - aiosignal: 1.3.1 - aniso8601: 9.0.1 - annotated-types: 0.5.0 - anyio: 3.7.1 - anytree: 2.9.0 - arrow: 1.3.0 - async-timeout: 4.0.3 - attrs: 23.1.0 - audioread: 3.0.1 - beautifulsoup4: 4.12.2 - bidict: 0.22.1 - black: 22.12.0 - blessed: 1.20.0 - cachetools: 5.3.1 - certifi: 2023.7.22 - cffi: 1.16.0 - cfgv: 3.4.0 - charset-normalizer: 3.3.0 - click: 8.1.7 - clipdetect: 0.1.3 - cmake: 3.27.6 - colorama: 0.4.6 - coloredlogs: 14.0 - contourpy: 1.1.1 - croniter: 1.3.15 - cycler: 0.12.0 - cython: 3.0.3 - dateutils: 0.6.12 - decorator: 5.1.1 - deepdiff: 6.6.0 - distlib: 0.3.7 - dnspython: 2.3.0 - editdistance: 0.6.2 - einops: 0.5.0 - et-xmlfile: 1.1.0 - eventlet: 0.33.3 - everyvoice: 0.1.20231005 - exceptiongroup: 1.1.3 - fastapi: 0.103.2 - filelock: 3.12.4 - flake8: 6.1.0 - flask: 2.2.5 - flask-cors: 4.0.0 - flask-restful: 0.3.10 - flask-socketio: 5.3.6 - flask-talisman: 1.1.0 - fonttools: 4.43.0 - frozenlist: 1.4.0 - fsspec: 2023.9.2 - g2p: 1.1.20230822 - gitlint-core: 0.19.1 - google-auth: 2.23.2 - google-auth-oauthlib: 1.0.0 - greenlet: 3.0.0 - grpcio: 1.59.0 - h11: 0.14.0 - humanfriendly: 10.0 - identify: 2.5.30 - idna: 3.4 - importlib-metadata: 6.8.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - isort: 5.12.0 - itsdangerous: 2.1.2 - jinja2: 3.1.2 - joblib: 1.3.2 - jsonschema: 4.19.1 - jsonschema-specifications: 2023.7.1 - kiwisolver: 1.4.5 - librosa: 0.9.2 - lightning: 2.0.4 - lightning-cloud: 0.5.39 - lightning-utilities: 0.9.0 - lit: 17.0.2 - llvmlite: 0.41.0 - loguru: 0.6.0 - markdown: 3.4.4 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - matplotlib: 3.6.0 - mccabe: 0.7.0 - mdurl: 0.1.2 - merge-args: 0.1.5 - mpmath: 1.3.0 - multidict: 6.0.4 - munkres: 1.1.4 - mypy: 1.5.1 - mypy-extensions: 1.0.0 - networkx: 2.8.4 - nltk: 3.7 - nodeenv: 1.8.0 - numba: 0.58.0 - numpy: 1.25.2 - oauthlib: 3.2.2 - openpyxl: 3.1.2 - ordered-set: 4.1.0 - packaging: 23.2 - pandas: 1.4.4 - panphon: 0.20.0 - pathspec: 0.11.2 - pillow: 10.0.1 - pip: 23.2.1 - platformdirs: 3.11.0 - pluggy: 1.3.0 - pooch: 1.7.0 - pre-commit: 3.4.0 - prompt-toolkit: 3.0.39 - protobuf: 4.24.4 - psutil: 5.9.5 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycodestyle: 2.11.0 - pycountry: 22.3.5 - pycparser: 2.21 - pydantic: 2.4.2 - pydantic-core: 2.10.1 - pyflakes: 3.1.0 - pygments: 2.16.1 - pyjwt: 2.8.0 - pympi-ling: 1.70.2 - pyparsing: 3.1.1 - pysdtw: 0.0.5 - pytest: 7.4.2 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-engineio: 4.7.1 - python-multipart: 0.0.6 - python-socketio: 5.9.0 - pytorch-lightning: 2.0.9.post0 - pytz: 2023.3.post1 - pyworld: 0.3.4 - pyyaml: 6.0.1 - questionary: 1.10.0 - readchar: 4.0.5 - referencing: 0.30.2 - regex: 2023.10.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - resampy: 0.4.2 - rich: 13.6.0 - rpds-py: 0.10.4 - rsa: 4.9 - scikit-learn: 1.3.1 - scipy: 1.11.3 - setuptools: 59.5.0 - sh: 2.0.6 - shellingham: 1.5.3 - simple-term-menu: 1.5.2 - simple-websocket: 1.0.0 - six: 1.16.0 - sniffio: 1.3.0 - soundfile: 0.12.1 - soupsieve: 2.5 - starlette: 0.27.0 - starsessions: 1.3.0 - sympy: 1.12 - tabulate: 0.8.10 - tensorboard: 2.14.1 - tensorboard-data-server: 0.7.1 - text-unidecode: 1.3 - threadpoolctl: 3.2.0 - tomli: 2.0.1 - torch: 2.0.1+cu117 - torchaudio: 2.0.2+cu117 - torchmetrics: 1.2.0 - tqdm: 4.66.1 - traitlets: 5.11.2 - triton: 2.0.0 - typer: 0.9.0 - types-python-dateutil: 2.8.19.14 - types-pyyaml: 6.0.12.12 - types-requests: 2.31.0.8 - types-setuptools: 68.2.0.0 - types-tabulate: 0.8.11 - typing-extensions: 4.8.0 - unicodecsv: 0.14.1 - urllib3: 2.0.6 - uvicorn: 0.23.2 - virtualenv: 20.24.5 - wcwidth: 0.2.8 - websocket-client: 1.6.3 - websockets: 11.0.3 - werkzeug: 2.2.3 - wheel: 0.41.2 - wsproto: 1.2.0 - yarl: 1.9.2 - zipp: 3.17.0
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.9.18
  - release: 4.15.0-204-generic
  - version: #215-Ubuntu SMP Fri Jan 20 18:24:59 UTC 2023

More info

No response

awaelchli commented 11 months ago

Hey @roedoejet Is this reproducible in our GAN examples? https://github.com/Lightning-AI/lightning/blob/master/examples/pytorch/domain_templates/generative_adversarial_net.py

roedoejet commented 11 months ago

> Hey @roedoejet Is this reproducible in our GAN examples? https://github.com/Lightning-AI/lightning/blob/master/examples/pytorch/domain_templates/generative_adversarial_net.py

I'm away now until next week, but I will give it a shot then and post an update here. Thanks.

roedoejet commented 11 months ago

Unfortunately, this is not reproducible in my environment with the GAN example linked above. I will poke around a bit more to see if I can find a minimal example.

awaelchli commented 11 months ago

@roedoejet Thanks for looking at it. Due to priorities, I won't have the bandwidth to search for the bug. If you find a way to reproduce this in a code example we can study, that would be invaluable.