To refactor the generator's loss computation so that it is done outside of the model, within the LightningModule, you can follow a similar approach to the one used for the classifier: move the loss into a dedicated method on the LightningModule and call it from the training_step and validation_step methods. Here is the refactored code for the ImageGenerator class:
```python
from pathlib import Path
from typing import Tuple

from lightning import LightningModule
from lightning.pytorch.loggers import MLFlowLogger
from torch import Tensor
from torch.nn.functional import binary_cross_entropy
from torch.optim import Adam

from vintage_models.autoencoder.vae.vae import Vae

class ImageGenerator(LightningModule):
    """Lightning module for image generation experiments."""

    def __init__(self, model: Vae) -> None:
        super().__init__()
        self.model = model
        self.optimizer = Adam(self.model.parameters(), lr=1e-3)
        self.training_step_outputs: list[Tensor] = []
        self.validation_step_outputs: list[Tensor] = []
        self.test_step_outputs: list[Tensor] = []

    def compute_loss(self, x: Tensor) -> Tensor:
        """Compute the VAE loss (reconstruction + KL divergence) outside of the model."""
        reconstructed = self.model(x)
        vector_size = x.size(-1) * x.size(-2)
        # Per-pixel binary cross-entropy, summed over each flattened image.
        reconstruction_loss = (
            binary_cross_entropy(
                reconstructed,
                x,
                reduction="none",
            )
            .reshape(-1, vector_size)
            .sum(dim=1)
        )
        # Closed-form KL divergence between the Gaussian posterior and the standard normal prior.
        mean, log_var = self.model.encoder.compute_mean_and_log_var(x)
        kl_div = -0.5 * (1 + log_var - log_var.exp() - mean.pow(2)).sum(dim=1)
        loss = kl_div.mean() + reconstruction_loss.mean()
        return loss

    def training_step(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        data, _ = batch
        loss = self.compute_loss(data)
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self) -> None:
        self.log(
            "training_loss",
            self.training_step_outputs[-1].item(),
            prog_bar=True,
            logger=True,
        )
        self.training_step_outputs.clear()

    def validation_step(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        data, _ = batch
        loss = self.compute_loss(data)
        self.validation_step_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self) -> None:
        self.log(
            "validation_loss",
            self.validation_step_outputs[-1].item(),
            prog_bar=True,
            logger=True,
        )
        self.validation_step_outputs.clear()

    def test_step(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        data, _ = batch
        generated = self.model.generate(1)
        self.test_step_outputs.append(generated)
        return generated

    def on_test_epoch_end(self) -> None:
        assert isinstance(self.logger, MLFlowLogger)
        saving_path = Path("/storage/ml") / str(self.logger.experiment_id) / "generated"
        if not saving_path.exists():
            saving_path.mkdir(parents=True, exist_ok=True)
        for idx, image in enumerate(self.test_step_outputs):
            self.logger.experiment.log_image(
                self.logger.run_id,
                image.squeeze().cpu().numpy(),
                f"generated_{idx}.png",
            )
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        return self.optimizer
```
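For reference, the kl_div line implements the standard closed-form KL divergence between the diagonal Gaussian posterior $q(z \mid x) = \mathcal{N}(\mu, \sigma^2)$ and the standard normal prior, as in the original VAE formulation:

$$
D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big)
= -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \sigma_j^2 - \mu_j^2\right)
$$

Here log_var corresponds to $\log \sigma^2$ and mean to $\mu$; summing over the latent dimension and averaging over the batch yields kl_div.mean().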
In this refactored code, the compute_loss method is responsible for computing the loss. It is called from the training_step and validation_step methods, so all loss computations are done outside of the model [1][2]. The test_step does not compute a loss; it generates images that are logged as MLflow artifacts at the end of the test epoch.
To continue talking to Dosu, mention @dosu.
The classifier computes the loss within the LightningModule, while the generator does not. Because a model can be used for different purposes, it makes sense to compute the loss in the LightningModule.