@alexrey88 @melisandeteng I updated the PR to include your feedback + added a check for loss weights: no need to compute losses whose weights are 0
I'm just doing a last check run, please wait before merging :)
Good point @melisandeteng, I agree. Next time around :)
@melisandeteng I agree the code I copied from Stack Overflow is not perfect, but it does not explain the -1s :(
# (Mother Of All PRs)

## Introduction
A shitload of changes, sorry about that.

Initially I just wanted to include AMP (Automatic Mixed Precision), i.e. enabling 16-bit training to save memory and gain speed at test time (it's easy to do at test time only, but people argue that if you're going to test in 16 bits you should train in 16 bits, so here we are).
Spoiler: AMP does not work yet. Any help is most welcome. See this PyTorch Discussion: https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372/8
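For reference, here is a minimal sketch of the training step AMP is supposed to enable, using `torch.cuda.amp` (the model, optimizer, and data are toy placeholders, not our trainer's API):

```python
import torch
from torch import nn

# Toy stand-ins for the trainer's model / data (hypothetical names)
model = nn.Linear(8, 1).cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(4, 8, device="cuda")
    y = torch.randn(4, 1, device="cuda")
    opt.zero_grad()
    with torch.cuda.amp.autocast():  # fp16 where safe, fp32 elsewhere
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()  # scale loss to avoid fp16 grad underflow
    scaler.step(opt)               # unscales grads, skips step on inf/nan
    scaler.update()

# The "Attempting to unscale FP16 gradients" error from the thread above
# is raised when the *parameters* are themselves fp16 (e.g. model.half());
# amp expects fp32 master weights and only autocasts the ops.
```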
WIP: something is currently very wrong with the masker, which produces weird zebra stripes. I'm looking into it.
## Changes

### Refactor

#### Masker
I wanted to clean up our `update_g` and `update_d` trainer methods, so I moved all the logic to:

- `masker_m_loss`: the loss from masking to the whole Masker generator
- `masker_s_loss`: the loss from segmentation to the Masker
- `masker_d_loss`: the loss from depth to the Masker
- `masker_c_loss`: the loss from latent classification to the Masker (even if we never use it, I thought why not?)

All these methods take a `for_` argument as they are used for both the G and D updates (see the sketch below). The logic is only slightly more complicated, but at least it is all unified in the same place instead of being scattered across the files, hundreds of lines apart. It's a tradeoff: differences between the G and D updates are less obvious, but differences between tasks are cleaner and task updates are unified.
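To make the `for_` pattern concrete, a minimal sketch (the loss body is made up; only the dispatch is the point):

```python
import torch
from torch import nn

bce = nn.BCEWithLogitsLoss()

def masker_m_loss(mask_logits, target, for_="G"):
    # Made-up body: only the for_ dispatch is the point here.
    assert for_ in {"G", "D"}
    if for_ == "D":
        # discriminator update: don't backprop into the masker
        mask_logits = mask_logits.detach()
    return bce(mask_logits, target)

# The same method serves both update_g and update_d:
logits = torch.randn(2, 1, 8, 8, requires_grad=True)
target = torch.randint(0, 2, (2, 1, 8, 8)).float()
g_loss = masker_m_loss(logits, target, for_="G")  # grads reach the masker
d_loss = masker_m_loss(logits, target, for_="D")  # grads don't
```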
I also cleaned up naming schemes to use simpler variable names: `target`, `task`, `domain` instead of `update_target`, `update_task`, `batch_domain`, as I think they are clear enough and they make the code more readable.

Potential sources of errors:
`amp` does not handle `BCELoss` because of stability issues in low precision, so I switched to `BCEWithLogitsLoss`. This means the `m` decoder's output is not normalized anymore, so you need to apply a sigmoid to get a [0, 1] mask! Please check that this is appropriately done in the code.
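In short (illustrative tensors, not the actual decoder):

```python
import torch
from torch import nn

m_logits = torch.randn(2, 1, 8, 8)  # raw (unnormalized) m decoder output
target = torch.randint(0, 2, (2, 1, 8, 8)).float()

# Training: the loss takes logits (the sigmoid is fused inside, amp-safe)
loss = nn.BCEWithLogitsLoss()(m_logits, target)

# Inference / logging: the sigmoid must now be applied explicitly
mask = torch.sigmoid(m_logits)  # values in [0, 1]
```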
`AdventAdversarialLoss` was borked due to re-normalizing already-normalized inputs. Check that for `s` and `m` we have the appropriate inputs.
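To spell out "appropriate inputs" (assuming the loss expects probabilities; shapes and class count are illustrative):

```python
import torch

s_logits = torch.randn(2, 11, 8, 8)  # segmentation logits (11 classes here)
m_logits = torch.randn(2, 1, 8, 8)   # mask logits

# Normalize exactly once before handing tensors to AdventAdversarialLoss:
s_probs = torch.softmax(s_logits, dim=1)  # for s
m_probs = torch.sigmoid(m_logits)         # for m
# The bug was the second normalization: a softmax over already-softmaxed
# probabilities flattens the distribution and breaks the loss.
```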
#### Painter

The `z is None` logic has moved into `trainer.sample_painter_z(batch_size)`. It's less obvious that it can be `None`, but it's more robust as one no longer needs to remember to check for `no_z`.
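Roughly, the idea (a hypothetical standalone version; the real method reads the trainer's opts instead of taking arguments):

```python
import torch

def sample_painter_z(batch_size, no_z, z_dim=32, device="cpu"):
    # Hypothetical signature: no_z / z_dim are made-up parameters.
    if no_z:
        return None  # painter configured without a latent input
    return torch.randn(batch_size, z_dim, device=device)

z = sample_painter_z(4, no_z=False)  # callers handle z is None downstream
```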
There's a new loss, `pl4m` = `painter_loss_for_masker`, which freezes the weights of the painter and propagates the loss from the painter's discriminator to the masker. Check it out and comment on it. Pay very special attention to the freezing procedure: is it the right one? What about `.detach()` or `.zero_grad()`? Do you think they are equivalent? Do you think what I did here is enough? Please research a little, it is important.
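For the discussion, here's a toy illustration of the difference between freezing and detaching (not the pl4m code itself):

```python
import torch
from torch import nn

painter = nn.Conv2d(3, 3, 3, padding=1)                   # toy painter
masker_out = torch.randn(1, 3, 8, 8, requires_grad=True)  # toy masker output

# Freezing: the painter's weights don't move, but the graph still runs
# *through* the painter, so the masker receives gradients.
for p in painter.parameters():
    p.requires_grad_(False)
painter(masker_out).mean().backward()
assert masker_out.grad is not None  # the masker gets gradients
assert painter.weight.grad is None  # the painter does not

# .detach() on the painter's *input* is NOT equivalent: it cuts the graph
# before the painter, so no gradient would reach the masker at all.
```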
#### Logger

Moved the logging logic to a `Logger` class to lighten the trainer's code.

### Additions
#### deeplabv3

There's a new `deeplabv3.py` file. Segmentation decoders were renamed to `DeepLabV2Decoder` and `DeepLabV3Decoder`. To match the original implementation, if using DLV3 the encoder is a `backbone` (`resnet` or `mobilenet`) returning `z` as a tuple (high-level features, low-level features). The `s` head uses both; the other heads only use the first. Maybe they should somehow use both too (@alexrey88, think about that for DADA). A shape-level sketch is below.
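A toy sketch of the new contract, shapes only (not the actual resnet/mobilenet code):

```python
import torch
from torch import nn

class ToyBackbone(nn.Module):
    # Stand-in for the resnet / mobilenet backbone
    def __init__(self):
        super().__init__()
        self.shallow = nn.Conv2d(3, 64, 3, stride=2, padding=1)
        self.deep = nn.Conv2d(64, 256, 3, stride=4, padding=1)

    def forward(self, x):
        low = self.shallow(x)  # low-level features (early layer)
        high = self.deep(low)  # high-level features (deep layer)
        return high, low       # z is a tuple when using DLV3

z = ToyBackbone()(torch.randn(1, 3, 64, 64))
s_head_input = z          # the s head consumes both feature levels
other_heads_input = z[0]  # d, m, c only use the high-level features
```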
#### load_opts

`load_opts` checks for `res_dim` and the architectures, which should match. It also checks that the encoder and segmentation decoder architectures match, though that one is not mandatory.
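Roughly the kind of check, sketched with hypothetical opts keys and a made-up `res_dim` mapping (the real check lives in `load_opts`):

```python
import warnings

def check_opts(opts):
    # Hypothetical keys and values -- the real ones live in load_opts.
    expected_res_dim = {"resnet": 2048, "mobilenet": 320}  # made-up mapping
    enc = opts["gen"]["encoder"]["architecture"]
    assert opts["gen"]["encoder"]["res_dim"] == expected_res_dim[enc], (
        "res_dim must match the encoder architecture"
    )
    # Encoder / segmentation decoder match is advised, not mandatory:
    if opts["gen"]["s"]["architecture"] != enc:
        warnings.warn("encoder and s decoder architectures differ")
```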
#### test_trainer.py

A new, simplified test you should always run. Just `python test_trainer.py` will work. It has a small batch size (2) and a small output size (images are 256px), so it should not take more than 5 minutes on a GPU (tested on a Titan Xp; you don't need a large GPU, so no excuse). It still features the cool colors and the auto-deletion of experiments and output dirs.

#### defaults.yaml
The `mobilenet` backbone is the new encoder default, as it has 40 times fewer parameters for theoretically comparable performance.

Two new mask images are now logged: (strictly) binary masks for mask > 0.5 and mask > 0.1, to give an idea of the distribution of values. The logged row is now: [input, input & float mask, input & binary mask (0.5), input & binary mask (0.1), target, float mask].
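The two binarized views relate to the float mask like this (illustrative tensors):

```python
import torch

float_mask = torch.sigmoid(torch.randn(1, 1, 8, 8))  # masker output in [0, 1]
bin_05 = (float_mask > 0.5).float()  # strict binary mask
bin_01 = (float_mask > 0.1).float()  # looser threshold
# Comparing the two binarized masks shows how much probability mass sits
# in (0.1, 0.5], i.e. how confident the masker's soft predictions are.
```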