ashleve / lightning-hydra-template

PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡🔥⚡

Adding several loss (criteria) #609

Open amirshamaei opened 1 year ago

amirshamaei commented 1 year ago

Hi, I am working on adapting the Hydra + PyTorch Lightning (H+PL) template for MRI reconstruction. I need several loss functions. What is the best way to add them to a Lightning module?

Currently, I have added the following arguments to `__init__()`. However, I believe this could be made more efficient and more general.

        criterion_1: Callable = torch.nn.MSELoss(),
        criterion_2: Optional[Callable] = None,
        criterion_3: Optional[Callable] = None,
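With three separate optional arguments, the training step has to skip whichever criteria were left as `None`. Below is a minimal sketch of that pattern in plain Python, so it runs without PyTorch. `sum_optional_criteria` and `squared_error` are hypothetical names, and each criterion is assumed to be an already-constructed loss object (for example `torch.nn.MSELoss()`, not the class `torch.nn.MSELoss`).

```python
from typing import Any, Callable, Optional, Sequence


def sum_optional_criteria(
    output: Any,
    target: Any,
    criteria: Sequence[Optional[Callable[[Any, Any], Any]]],
) -> Any:
    """Hypothetical helper: add up every criterion that is not None."""
    active = [c for c in criteria if c is not None]
    if not active:
        # Fail loudly instead of silently training with no loss at all.
        raise ValueError("at least one criterion must be set")
    total = active[0](output, target)
    for criterion in active[1:]:
        total = total + criterion(output, target)
    return total


# Stand-in loss for the demo; with PyTorch this would be torch.nn.MSELoss().
def squared_error(output: float, target: float) -> float:
    return (output - target) ** 2


print(sum_optional_criteria(3.0, 1.0, [squared_error, None, None]))  # 4.0
```

Inside the module this would be called as `sum_optional_criteria(output_imgs, target_img, [self.criterion_1, self.criterion_2, self.criterion_3])`, assuming the arguments are stored under those (hypothetical) attribute names. The fixed number of slots is the main limitation of this design, which is why a list is more general.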
amirshamaei commented 1 year ago

I improved it a bit, as follows:

criterions: List[Callable] = [torch.nn.MSELoss()],

and in the config file:

criterions:
  - _target_: torch.nn.MSELoss
  - _target_: torch.nn.MSELoss

and in the training step:

loss = 0.0  # start from zero so that `+=` has something to add to
for criterion in self.criterions:
    loss += criterion(output_imgs, target_img)
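The list-based loop above generalizes to any number of criteria. A sketch of the same summation that also keeps each individual term, so every loss can be logged separately, is shown below. It is plain Python so it runs without PyTorch; `sum_criteria`, `squared_error` and `absolute_error` are hypothetical names, and the criteria are assumed to be callables taking `(output, target)`, as the loss objects Hydra instantiates from `_target_: torch.nn.MSELoss` are.

```python
from typing import Any, Callable, Dict, Sequence, Tuple


def sum_criteria(
    output: Any,
    target: Any,
    criteria: Sequence[Callable[[Any, Any], Any]],
) -> Tuple[Any, Dict[str, Any]]:
    """Hypothetical helper: return the summed loss and each term by name."""
    loss = 0.0  # must be initialised before adding the first term
    terms: Dict[str, Any] = {}
    for i, criterion in enumerate(criteria):
        value = criterion(output, target)
        # Functions have __name__; loss objects fall back to their class name
        # (e.g. "MSELoss"). The index keeps two criteria of the same type apart.
        name = f"{i}_{getattr(criterion, '__name__', type(criterion).__name__)}"
        terms[name] = value
        loss = loss + value
    return loss, terms


def squared_error(output: float, target: float) -> float:
    return (output - target) ** 2


def absolute_error(output: float, target: float) -> float:
    return abs(output - target)


total, terms = sum_criteria(3.0, 1.0, [squared_error, absolute_error])
print(total)  # 6.0
print(terms)  # {'0_squared_error': 4.0, '1_absolute_error': 2.0}
```

In a `training_step`, this could be used as `loss, terms = sum_criteria(output_imgs, target_img, self.criterions)` followed by `self.log_dict({f"train/{k}": v for k, v in terms.items()})`. If any criterion has parameters or buffers, wrapping the list in `torch.nn.ModuleList` registers them as submodules so they move to the same device as the model.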

I would appreciate it if you could suggest a better approach.
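One possible refinement, offered here as an assumption rather than an established template convention, is to wrap all criteria in a single callable with a name and a weight per term. The module then takes one `criterion` argument, and the config decides which losses are used and how strongly each one counts. `WeightedSumLoss` is a hypothetical class written in plain Python; with PyTorch it could instead subclass `torch.nn.Module` and hold the losses in a `torch.nn.ModuleDict`.

```python
from typing import Any, Callable, Dict, Optional


class WeightedSumLoss:
    """Hypothetical wrapper: one callable that combines named, weighted criteria.

    criteria: name -> loss callable (e.g. {"mse": torch.nn.MSELoss()})
    weights:  name -> float; names missing from `weights` default to 1.0
    """

    def __init__(
        self,
        criteria: Dict[str, Callable[[Any, Any], Any]],
        weights: Optional[Dict[str, float]] = None,
    ) -> None:
        if not criteria:
            raise ValueError("criteria must not be empty")
        weights = dict(weights or {})
        unknown = set(weights) - set(criteria)
        if unknown:
            # Catch typos in the config instead of silently ignoring them.
            raise ValueError(f"weights given for unknown criteria: {sorted(unknown)}")
        self.criteria = dict(criteria)
        self.weights = {name: float(weights.get(name, 1.0)) for name in self.criteria}
        self.last_terms: Dict[str, Any] = {}

    def __call__(self, output: Any, target: Any) -> Any:
        total = 0.0
        self.last_terms = {}
        for name, criterion in self.criteria.items():
            value = criterion(output, target)
            self.last_terms[name] = value  # unweighted term, handy for logging
            total = total + self.weights[name] * value
        return total


def squared_error(output: float, target: float) -> float:
    return (output - target) ** 2


def absolute_error(output: float, target: float) -> float:
    return abs(output - target)


loss_fn = WeightedSumLoss(
    {"l2": squared_error, "l1": absolute_error},
    weights={"l1": 0.5},
)
print(loss_fn(3.0, 1.0))  # 5.0, i.e. 1.0 * 4.0 + 0.5 * 2.0
print(loss_fn.last_terms)  # {'l2': 4.0, 'l1': 2.0}
```

With Hydra, the `criterion` config would point `_target_` at wherever this class lives in your project, give `criteria` as a mapping whose values are entries like `_target_: torch.nn.MSELoss`, and give `weights` as a plain mapping of numbers. Hydra instantiates nested `_target_` entries recursively by default, so the wrapper receives ready-made loss objects.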