mlangguth89 / downscaling_benchmark


Implement Harris et al WGAN #11

Open seblehner opened 5 months ago

seblehner commented 5 months ago

Feature issue/branch for the implementation of the Harris et al WGAN into the benchmark framework.

paulaharder commented 4 months ago

Here is a nice, cleaned-up repo for this work: https://github.com/bobbyantonio/downscaling-cgan

seblehner commented 1 month ago

@mlangguth89 To summarise, the current state:

Hope it doesn't consume too much time to iron out the loss.

mlangguth89 commented 1 month ago

Hi @seblehner, a few preliminary updates: I've found that the critic losses have been quickly diverging and exploding (c_loss attained values in the range of -O(10^6) after 100 iterations, whereas cg_loss ended up at +O(10^6)). This could be largely alleviated by replacing the softplus activation in the final output layer of the generator. The softplus layer is a proper choice to ensure positive-definite outputs for (transformed) precipitation data; for normalized quantities that may attain positive and negative values, however, it is a bad choice. The generator therefore now applies a linear activation, which seems to fix the issue.
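
A minimal sketch of the output-layer change, assuming a Keras-style generator; the layer construction, channel counts and shapes are illustrative and not taken from the actual repo code:

```python
import tensorflow as tf
from tensorflow.keras import layers

def generator_output_layer(positive_definite: bool, n_channels_out: int = 1):
    """Final conv layer of the generator.

    Softplus guarantees strictly positive outputs, which suits (transformed)
    precipitation, but it cannot reproduce the negative values present in
    normalized targets; a linear activation avoids that restriction.
    """
    activation = "softplus" if positive_definite else "linear"
    return layers.Conv2D(n_channels_out, kernel_size=1, activation=activation)

# e.g. for normalized targets that may be negative:
out_layer = generator_output_layer(positive_definite=False)
print(out_layer(tf.zeros([1, 8, 8, 16])).shape)  # (1, 8, 8, 1)
```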
Additionally, I fixed the issue with the gradient penalty, which required adaptations to properly run the discriminator model on a mixture of generated and ground-truth data (cf. Eq. 5 in Harris et al., 2022).
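
For reference, a minimal sketch of such a gradient-penalty term (the standard WGAN-GP penalty on interpolates between real and generated samples, as referenced via Eq. 5 in Harris et al., 2022). The critic is assumed to take the conditioning fields and the high-resolution sample as inputs; the exact call signature and NHWC layout are assumptions, not the repo's actual interface:

```python
import tensorflow as tf

def gradient_penalty(critic, cond, real_hr, fake_hr):
    """WGAN-GP term: penalize deviations of the critic's gradient norm from 1
    on random mixtures of real and generated high-resolution fields."""
    batch_size = tf.shape(real_hr)[0]
    # one interpolation weight per sample, broadcast over spatial dims and channels (NHWC assumed)
    eps = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    mixed = eps * real_hr + (1.0 - eps) * fake_hr

    with tf.GradientTape() as tape:
        tape.watch(mixed)                               # mixed is a plain tensor, not a variable
        score = critic([cond, mixed], training=True)    # assumed critic inputs: (conditioning, sample)
    grads = tape.gradient(score, mixed)
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(grad_norm - 1.0))
```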

mlangguth89 commented 4 weeks ago

Several fixes to the data pipeline have been realised. However, the code currently still fails since shape_in is so far neither deduced from the data-pipeline objects nor parsed correctly when instantiating the model class. Regarding the latter, the HarrisWGAN class currently expects a dictionary providing all input and output shapes of the discriminator and the generator. However, this is redundant, since the shape of the (coarse) input data together with the downscaling factor and the number of noise channels is sufficient to infer all dimensions. Thus, the original strategy of parsing the input shape of the data as a list can still be realized. The former point, however, requires further adaptation/handling in the prepare_dataset method.
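
A minimal sketch of how the remaining shapes could be inferred from the coarse input shape, the downscaling factor and the number of noise channels; the function name, the returned keys and the assumption that noise is concatenated at the coarse resolution are hypothetical and would have to match the actual architecture:

```python
def infer_shapes(shape_in, ds_factor, n_noise_channels, n_channels_out=1):
    """Derive generator/critic input and output shapes from the coarse input
    shape (height, width, channels), the downscaling factor and the number
    of noise channels."""
    h, w, c = shape_in
    hr_shape = (h * ds_factor, w * ds_factor, n_channels_out)   # high-resolution target field
    noise_shape = (h, w, n_noise_channels)                      # noise at coarse resolution (assumed)
    return {
        "gen_in": [shape_in, noise_shape],
        "gen_out": hr_shape,
        "critic_in": [shape_in, hr_shape],
    }

# example: 16x16 coarse input with 9 predictor channels, downscaling factor 10, 4 noise channels
print(infer_shapes((16, 16, 9), ds_factor=10, n_noise_channels=4))
```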
In summary, the following tasks are left open for now:

mlangguth89 commented 3 weeks ago

Training now works technically. However, the training is very slow (expected due to the model size) and the training dynamics are not good yet. With a batch size of 4, more than 40K training steps are required per epoch, so one epoch takes about one day on Juwels Cluster. Larger mini-batch sizes do not fit into memory, at least on the V100 nodes, but they may work on Juwels Booster with its A100 nodes. Eventually, distributed training to increase the effective batch size should be considered.
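
A minimal sketch of single-node multi-GPU data-parallel training with tf.distribute.MirroredStrategy, assuming a Keras/TensorFlow training setup (which the later remarks about fit and ModelCheckpoint suggest); a trivial stand-in model is used here instead of the actual WGAN, and multi-node training on Juwels would need e.g. MultiWorkerMirroredStrategy or Horovod instead:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()            # uses all visible GPUs on the node
# per-GPU batch size of 4 -> effective (global) batch size scales with the replica count
global_batch_size = 4 * strategy.num_replicas_in_sync

with strategy.scope():
    # model, optimizers and metrics have to be created inside the strategy scope;
    # trivial stand-in model instead of the actual WGAN
    model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

# tf.data datasets are distributed across the replicas automatically by fit()
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([256, 8]), tf.random.normal([256, 1]))
).batch(global_batch_size)
model.fit(ds, epochs=1)
```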

However, the training dynamics require stabilization first. After about 1K training steps, the critic loss for the generator becomes strongly negative, which results in exploding gradients at later training steps. Several options to stabilize training are possible, which are in summary:

paulaharder commented 1 week ago

Here are some loss curves from my run at the moment, maybe helpful to compare: [loss curve plots]

paulaharder commented 1 week ago
mlangguth89 commented 1 week ago

I've performed several experiments with variations of the batch size and the learning rate. On a single A100 GPU, a batch size of 8 is possible, whereas a batch size of 16 is also possible when the number of channels in the first conv layer of the critic/discriminator (N_filters disc) is reduced from 512 to 128 (cf. @paulaharder's setting). In general, larger batch sizes and smaller learning rates with $l_{r,gen} > l_{r,cri}$ (similar to the Sha WGAN config, but contrary to Paula's config with $l_{r,gen} < l_{r,cri}$) seem to stabilize the training dynamics. Both result in smoother weight updates of the networks, and it could also be that a batch size of 64 (with data-distributed training) would enable a learning rate configuration as in Paula's experiments. At least, I've noticed that a batch size of 16 results in better training dynamics compared to a batch size of 8 in the respective experiments.
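
A minimal sketch of such a learning-rate configuration with separate optimizers for generator and critic; the concrete optimizer type and values are illustrative, not the final settings:

```python
import tensorflow as tf

# separate optimizers so that l_r,gen > l_r,cri can be set independently;
# the values below are only illustrative
opt_generator = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)
opt_critic = tf.keras.optimizers.Adam(learning_rate=5e-5, beta_1=0.5, beta_2=0.9)
```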
A more comprehensive documentation is available in this hedgedoc.

I've also submitted 'long-term' experiments with a batch size of 8 and 16 to figure out whether training is ultimately stabilized. However, the jobs are currently queued.

mlangguth89 commented 2 days ago

It is found that the training can be stabilized with the following settings:

Furthermore, fine-grained learning rate scheduling is probably beneficial. This can be achieved by reducing the steps_per_epoch parameter of the fit method. This also enables more frequent checkpointing, since setting save_freq to an integer value in the ModelCheckpoint callback did not work as expected. Future trainings of the Harris WGAN set steps_per_epoch=4012, which corresponds to one third of the actual steps per epoch with a batch size of 16.
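
A minimal sketch of this shortened-epoch workaround, assuming the Keras fit/ModelCheckpoint interface; the model, dataset and checkpoint path are stand-ins, only steps_per_epoch=4012 is taken from the comment above:

```python
import tensorflow as tf

# stand-ins for the actual WGAN model and data pipeline (illustrative only)
model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
train_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 8]), tf.random.normal([64, 1]))
).batch(16).repeat()

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="harris_wgan_{epoch:03d}.weights.h5",  # hypothetical path
    save_weights_only=True,
    save_freq="epoch",   # save at epoch end; an integer save_freq did not work as expected
)

# one "epoch" now covers only a third of a full pass over the data (batch size 16),
# so checkpointing and epoch-based LR scheduling run three times per full pass
model.fit(train_dataset, epochs=3, steps_per_epoch=4012, callbacks=[checkpoint_cb])
```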