cc-ai / climategan

Code and pre-trained model for the algorithm generating visualisations of 3 climate-change-related events: floods, wildfires and smog.
https://thisclimatedoesnotexist.com
GNU General Public License v3.0

[WIP] MOAPR #171

Closed vict0rsch closed 3 years ago

vict0rsch commented 3 years ago

(Mother Of All PRs)

Introduction

A shitload of changes, sorry about that.

Initially I just wanted to add AMP (Automatic Mixed Precision), i.e. enable 16-bit training to save memory and speed things up at test time (doing it at test time only is easy, but the usual argument is that if you're going to test in 16 bits you should also train in 16 bits, so here we are).

Spoiler: AMP does not work yet. Any help is most welcome. See this PyTorch Discussion: https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372/8
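For reference, the standard AMP recipe looks roughly like this (a minimal sketch with a generic model and optimizer, not our trainer code). The ValueError in the linked thread typically means parameters or gradients ended up in FP16 (e.g. after calling .half() on the model) when the scaler tries to unscale them; with AMP the parameters are supposed to stay in FP32 and only the forward pass runs under autocast.

```python
import torch

# Minimal AMP training step (PyTorch >= 1.6). Model parameters stay in FP32;
# autocast only casts the forward computations to lower precision where safe.
scaler = torch.cuda.amp.GradScaler()

def amp_step(model, optimizer, loss_fn, x, y):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        pred = model(x)                  # forward pass in mixed precision
        loss = loss_fn(pred, y)
    scaler.scale(loss).backward()        # backward on the scaled loss
    scaler.step(optimizer)               # unscales grads, skips step on inf/nan
    scaler.update()
    return loss.item()
```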

WIP: something is currently very wrong with the masker which produces weird zebra stripes. I'm looking into it.

Changes

Refactor

Masker

I wanted to clean up our update_g and update_d trainer methods, so I moved all the logic into dedicated per-task methods.

All these methods include a for_ argument as they are used for both the G and D updates. The logic is only slightly more complicated, but at least it is all unified in the same place instead of being scattered across files, hundreds of lines apart. It's a tradeoff: differences between the G and D updates are less obvious, but differences between tasks are cleaner and task updates are unified.

I also cleaned up the naming scheme to use simpler variable names, target / task / domain instead of update_target / update_task / batch_domain, as I think it's clear enough and it makes the code more readable.
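To illustrate the pattern, here is a sketch of what a shared task-level loss method can look like; the method and attribute names below are hypothetical, not the trainer's exact API.

```python
# Illustrative only: hypothetical names, not the exact trainer methods.
def mask_update_loss(self, batch, for_="G"):
    """Shared m-task loss, used by both the G and the D updates."""
    assert for_ in {"G", "D"}
    x, target = batch["x"], batch["m"]
    prediction = self.G.decoders["m"](self.G.encode(x))
    if for_ == "D":
        # the discriminator update must not backpropagate into the masker
        prediction = prediction.detach()
    return self.losses[for_]["m"](prediction, target)
```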

Potential source of errors: AMP does not handle BCELoss because of stability issues in low precision, so I switched to BCEWithLogitsLoss. This means the m decoder's output is no longer normalized, so you need to apply a sigmoid to get a [0, 1] mask! Please check that this is done appropriately in the code.
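Concretely, something like this (stand-in tensors, not the trainer's variables): the loss consumes raw logits, and a sigmoid is applied wherever a [0, 1] mask is needed.

```python
import torch
import torch.nn as nn

# BCEWithLogitsLoss is AMP-safe and expects raw (un-normalized) logits.
bce_logits = nn.BCEWithLogitsLoss()

m_logits = torch.randn(2, 1, 256, 256)                        # stand-in for the m decoder output
target_mask = torch.randint(0, 2, (2, 1, 256, 256)).float()   # stand-in ground-truth mask

loss = bce_logits(m_logits, target_mask)

# Wherever a [0, 1] mask is consumed (painter input, logging, metrics...):
mask = torch.sigmoid(m_logits)
binary_mask = (mask > 0.5).float()
```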

AdventAdversarialLoss was broken because it re-normalized already-normalized inputs. Check that the s and m tasks now receive the appropriate inputs.

Painter

The z is None logic has moved into trainer.sample_painter_z(batch_size). It's less obvious that z can be None, but it's more robust since one does not need to remember to check for no_z.
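A minimal sketch of what that helper can look like (no_z is the existing flag; latent_dim and the exact option paths are assumptions):

```python
import torch

def sample_painter_z(self, batch_size):
    """Return a latent z for the painter, or None when the painter is
    configured without a latent input. Option names are illustrative."""
    if self.opts.gen.p.no_z:
        return None
    return torch.randn(batch_size, self.opts.gen.p.latent_dim, device=self.device)
```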

There's a new loss, pl4m = painter_loss_for_masker, which freezes the weights of the painter and propagates the loss from the painter's discriminator to the masker. Check it out and comment on it. Pay very special attention to the freezing procedure: is it the right one? What about .detach() or zero_grad()? Do you think they are equivalent? Do you think what I did here is enough? Please research a little, it is important.
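For reference, a sketch of the idea under assumed attribute names (G.painter, D["p"], gan_loss are illustrative). Note that freezing with requires_grad_(False) keeps the computation graph through the painter, so gradients can still reach the masker through the mask, whereas detaching the painter's output would cut that path entirely.

```python
def painter_loss_for_masker(self, x, mask, z=None):
    """Sketch of pl4m: freeze the painter so only the masker gets gradients,
    then back-propagate the painter-discriminator's GAN loss through the mask."""
    for p in self.G.painter.parameters():
        p.requires_grad_(False)                 # painter weights will not be updated

    # the mask keeps its graph, so gradients flow back into the masker
    fake = self.G.painter(z, x * (1.0 - mask))
    d_out = self.D["p"](fake)
    loss = self.gan_loss(d_out, target_is_real=True)   # try to fool the painter's D

    for p in self.G.painter.parameters():
        p.requires_grad_(True)                  # restore for the painter's own update
    return loss
```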

Logger

moved the logging logic to a Logger class to lighten the trainer's code

Additions

deeplabv3

there's a new deeplabv3.py file. Segmentation decoders were renamed to DeepLabV2Decoder and DeepLabV3Decoder. To match the original implementation, when using DLV3 the encoder is a backbone (resnet or mobilenet) that returns z as a tuple (high-level features, low-level features). The s head uses both; the other heads only use the first. Maybe they should somehow use both too (@alexrey88 think about that for DADA).
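Schematically, with a stand-in backbone (made-up module names, not the classes in deeplabv3.py):

```python
import torch
import torch.nn as nn

class BackboneStub(nn.Module):
    """Stand-in for the resnet/mobilenet backbone used with DeepLabV3:
    it returns (high-level features, low-level features)."""
    def __init__(self):
        super().__init__()
        self.low_block = nn.Conv2d(3, 24, 3, stride=4, padding=1)
        self.high_block = nn.Conv2d(24, 320, 3, stride=4, padding=1)

    def forward(self, x):
        low = self.low_block(x)      # low-level, high-resolution features
        high = self.high_block(low)  # high-level, low-resolution features
        return high, low

backbone = BackboneStub()
z = backbone(torch.randn(2, 3, 256, 256))
s_input = z          # the s head consumes both elements of the tuple
other_input = z[0]   # the other heads only use the high-level features
```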

load_opts checks that res_dim and the architectures match:

It also checks that the encoder and segmentation decoder architectures match, though this is not mandatory.

test_trainer.py

A new, simplified test you should always run: just python test_trainer.py will work. It uses a small batch_size (2) and a small output size (images are 256), so it should not take more than 5 minutes on a GPU (tested on a Titan Xp; you don't need a large GPU, so no excuse). It still features the cool colors and auto-deletion of experiments and output_dirs.

defaults.yaml

mobilenet backbone is the new encoder default as it has 40 times fewer parameters for theoretically comparable performance.

now 2 new mask images are logged: (strictly) binary masks for mask > 0.5 and mask > 0.1, to give an idea of the distribution of values

now: [input, input & float mask, input & binary mask (0.5), input & binary mask (0.1), target, float mask]

[screenshot of the logged images, 2020-11-09 at 12:42]

vict0rsch commented 3 years ago

@alexrey88 @melisandeteng I updated the PR to include your feedback + added a check for loss weights: no need to compute losses whose weights are 0.
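i.e. something along these lines (illustrative names, not the trainer's exact option structure):

```python
def total_weighted_loss(predictions, targets, criteria, lambdas):
    """Illustrative: only compute losses whose weight (lambda) is non-zero."""
    total = 0.0
    for name, weight in lambdas.items():
        if weight == 0:
            continue                 # skip dead losses entirely, no wasted compute
        total = total + weight * criteria[name](predictions[name], targets[name])
    return total
```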

melisandeteng commented 3 years ago

I'm just doing a last check run, please wait before merging :)

vict0rsch commented 3 years ago

Good point @melisandeteng I agree. next time around :)

vict0rsch commented 3 years ago

@melisandeteng I agree the code I copied from Stack Overflow is not perfect but it does not explain the -1s :(