Yann-CV / vintage-models

Personal PyTorch implementations of the most famous neural network models
GNU General Public License v3.0

All losses should be computed outside of the model, in the LightningModule #20

Open Yann-CV opened 2 weeks ago

Yann-CV commented 2 weeks ago

The classifier computes its loss in the LightningModule, while the generator does not.

Because a model can be used for different purposes, it makes sense to compute the loss in the LightningModule.

dosubot[bot] commented 2 weeks ago

To move the generator's loss computation out of the model and into the LightningModule, you can follow the same approach used for the classifier. Here is how you can refactor the ImageGenerator class:

  1. Move the loss computation to a separate method in the LightningModule.
  2. Call this method in the training_step and validation_step methods. The test_step method only generates images, so it needs no loss.
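The two steps above can be sketched without Lightning or PyTorch. In this minimal, hypothetical sketch, `ToyModel`, `StepModule`, and the squared-error objective are illustrative stand-ins (not part of vintage-models). The model only maps inputs to outputs. The wrapper owns the loss and reuses the same `compute_loss` in both step methods.

```python
from dataclasses import dataclass, field


class ToyModel:
    """Hypothetical model: maps inputs to outputs and knows nothing about losses."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def __call__(self, x: list[float]) -> list[float]:
        return [self.scale * v for v in x]


@dataclass
class StepModule:
    """Hypothetical stand-in for a LightningModule that owns the loss computation."""

    model: ToyModel
    training_step_outputs: list[float] = field(default_factory=list)
    validation_step_outputs: list[float] = field(default_factory=list)

    def compute_loss(self, x: list[float]) -> float:
        # Reconstruction-style objective: mean squared error between input and output.
        output = self.model(x)
        return sum((o - v) ** 2 for o, v in zip(output, x)) / len(x)

    def training_step(self, batch: tuple[list[float], list[int]]) -> float:
        data, _ = batch  # labels are ignored, as in ImageGenerator
        loss = self.compute_loss(data)
        self.training_step_outputs.append(loss)
        return loss

    def validation_step(self, batch: tuple[list[float], list[int]]) -> float:
        data, _ = batch
        loss = self.compute_loss(data)
        self.validation_step_outputs.append(loss)
        return loss
```

Because the objective lives in `StepModule`, you can train the same `ToyModel` with a different loss by changing only the wrapper.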

Here is the refactored code for the ImageGenerator class:

from pathlib import Path
from typing import Tuple

from lightning import LightningModule
from lightning.pytorch.loggers import MLFlowLogger
from torch import Tensor
from torch.nn.functional import binary_cross_entropy
from torch.optim import Adam

from vintage_models.autoencoder.vae.vae import Vae

class ImageGenerator(LightningModule):
    """Lightning module for image generation experiments."""

    def __init__(self, model: Vae) -> None:
        super().__init__()
        self.model = model
        self.optimizer = Adam(self.model.parameters(), lr=1e-3)
        self.training_step_outputs: list[Tensor] = []
        self.validation_step_outputs: list[Tensor] = []
        self.test_step_outputs: list[Tensor] = []

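    def compute_loss(self, x: Tensor) -> Tensor:
        """Negative ELBO of the VAE, averaged over the batch."""
        # Reconstruction term: per-pixel BCE summed over each image.
        reconstructed = self.model(x)
        vector_size = x.size(-1) * x.size(-2)
        reconstruction_loss = (
            binary_cross_entropy(
                reconstructed,
                x,
                reduction="none",
            )
            .reshape(-1, vector_size)
            .sum(dim=1)
        )

        # KL term: closed-form KL(N(mean, exp(log_var)) || N(0, I)) per sample.
        # Note: this runs the encoder a second time on the same batch.
        mean, log_var = self.model.encoder.compute_mean_and_log_var(x)
        kl_div = -0.5 * (1 + log_var - log_var.exp() - mean.pow(2)).sum(dim=1)

        return kl_div.mean() + reconstruction_loss.mean()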

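    def training_step(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        data, _ = batch
        loss = self.compute_loss(data)
        # Detach before storing so the autograd graph is not kept alive all epoch.
        self.training_step_outputs.append(loss.detach())
        return loss

    def on_train_epoch_end(self) -> None:
        # Log the mean of the step losses collected during the epoch.
        epoch_loss = sum(self.training_step_outputs) / len(self.training_step_outputs)
        self.log(
            "training_loss",
            epoch_loss,
            prog_bar=True,
            logger=True,
        )
        self.training_step_outputs.clear()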

    def validation_step(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        data, _ = batch
        loss = self.compute_loss(data)
        self.validation_step_outputs.append(loss)
        return loss

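    def on_validation_epoch_end(self) -> None:
        # Log the mean of the step losses collected during validation.
        epoch_loss = sum(self.validation_step_outputs) / len(self.validation_step_outputs)
        self.log(
            "validation_loss",
            epoch_loss,
            prog_bar=True,
            logger=True,
        )
        self.validation_step_outputs.clear()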

    def test_step(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        data, _ = batch
        generated = self.model.generate(1)
        self.test_step_outputs.append(generated)
        return generated

    def on_test_epoch_end(self) -> None:
        assert isinstance(self.logger, MLFlowLogger)
        saving_path = Path("/storage/ml") / str(self.logger.experiment_id) / "generated"
        if not saving_path.exists():
            saving_path.mkdir(parents=True, exist_ok=True)

        for idx, image in enumerate(self.test_step_outputs):
            self.logger.experiment.log_image(
                self.logger.run_id,
                image.squeeze().cpu().numpy(),
                f"generated_{idx}.png",
            )

        self.test_step_outputs.clear()

    def configure_optimizers(self):
        return self.optimizer
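
For reference, compute_loss can be read as the negative evidence lower bound of a VAE with a Bernoulli decoder (the BCE term) and a diagonal Gaussian encoder, averaged over the batch. In the notation below, $B$ is the batch size, $D$ is vector_size (the pixel count), $J$ is the latent dimension, $\hat{x}$ is reconstructed, and $\mu$ and $\log\sigma^2$ are mean and log_var. The index names are chosen here for illustration.

```latex
\mathcal{L} = \frac{1}{B}\sum_{b=1}^{B}\Bigg[
  -\sum_{d=1}^{D}\Big(x_{bd}\log\hat{x}_{bd} + (1 - x_{bd})\log(1 - \hat{x}_{bd})\Big)
  \;-\; \frac{1}{2}\sum_{j=1}^{J}\Big(1 + \log\sigma^2_{bj} - \sigma^2_{bj} - \mu^2_{bj}\Big)
\Bigg],
\qquad \sigma^2_{bj} = \exp\big(\log\sigma^2_{bj}\big)
```

The second sum is the closed form of $D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big)$ summed over latent dimensions. Because both terms are averaged over the same $B$ samples, kl_div.mean() + reconstruction_loss.mean() equals the batch mean of their per-sample sum.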

In this refactored code, the compute_loss method is responsible for computing the loss. It is called in training_step and validation_step, so the loss is computed in the LightningModule rather than inside the model. The test_step method does not compute a loss; it only samples images with self.model.generate(1) [1][2].
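To sanity-check the arithmetic in compute_loss without PyTorch, here is a minimal pure-Python sketch of the same quantities. The function names `bce_sum`, `kl_to_standard_normal`, and `vae_loss` are hypothetical. Plain lists stand in for one flattened image and one latent code per sample.

```python
import math

# One sample: (reconstructed pixels, target pixels, latent mean, latent log-variance).
Sample = tuple[list[float], list[float], list[float], list[float]]


def bce_sum(reconstructed: list[float], target: list[float], eps: float = 1e-12) -> float:
    """Per-pixel binary cross-entropy summed over one flattened image."""
    total = 0.0
    for p, t in zip(reconstructed, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp probabilities to avoid log(0)
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total


def kl_to_standard_normal(mean: list[float], log_var: list[float]) -> float:
    """Closed-form KL(N(mean, exp(log_var)) || N(0, 1)) summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - math.exp(lv) - m * m for m, lv in zip(mean, log_var))


def vae_loss(batch: list[Sample]) -> float:
    """Batch mean of KL plus batch mean of reconstruction, mirroring compute_loss."""
    reconstruction = [bce_sum(r, x) for r, x, _, _ in batch]
    kl = [kl_to_standard_normal(m, lv) for _, _, m, lv in batch]
    return sum(kl) / len(kl) + sum(reconstruction) / len(reconstruction)
```

For probabilities away from 0 and 1, these values should match the tensor version. PyTorch handles the edges differently: it clamps the log output instead of the probabilities.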
