Hi,
Regarding multinode training: You are correct, each node trains only a single decoder, and the encoder is shared between all nodes. Gradients are shared with regard to the encoder, but each decoder is trained on a single node. For 6 domains, we used 6 nodes.
Regarding single node training: This is indeed a bug. A solution would be to keep a list of decoders (plus a list of optimizers) and update the relevant decoder based on the domain of the current batch.
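A minimal sketch of that fix, assuming a `Decoder` class and a `reconstruction_loss` like the ones in the repo (the names here are placeholders, not the actual `src/train.py` code):

```python
import torch

# Hypothetical sketch: one decoder and one optimizer per domain on a single node.
n_domains = 6
decoders = [Decoder() for _ in range(n_domains)]          # placeholder for the repo's decoder class
optimizers = [torch.optim.Adam(d.parameters(), lr=1e-3) for d in decoders]

def train_step(encoder, encoder_optimizer, batch, domain_id):
    z = encoder(batch)                                     # shared encoder across all domains
    decoder, optimizer = decoders[domain_id], optimizers[domain_id]
    loss = reconstruction_loss(decoder(z, batch), batch)   # placeholder loss
    encoder_optimizer.zero_grad()
    optimizer.zero_grad()
    loss.backward()
    encoder_optimizer.step()
    optimizer.step()                                       # only the current domain's decoder is updated
    return loss.item()
```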
Regarding weights: The `bestmodel*.pth` file is the checkpoint that produced the lowest eval loss across all epochs per musical domain. The `lastmodel.pth` file contains the same encoder across all domains.
To transfer between domains, a separate decoder is required per domain, as stated in the paper: ![image](https://user-images.githubusercontent.com/902947/75977171-e1b13000-5eec-11ea-8a59-aa01a593e1a5.png)
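A rough sketch of what translation then looks like (interface details assumed, not taken from the repo): encode with the shared encoder, decode with the target domain's decoder.

```python
import torch

def translate(encoder, decoders, wav, target_domain):
    # Hypothetical inference step: shared encoder -> domain-agnostic latent,
    # then the target domain's decoder synthesizes the output waveform.
    with torch.no_grad():
        z = encoder(wav)
        return decoders[target_domain](z)
```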
In a distributed setup, `discriminator` and `encoder` are wrapped with `DistributedDataParallel`, whereas `decoder` only has `DataParallel`: https://github.com/facebookresearch/music-translation/blob/fd51cbcbeb0af3de0930e79c25b539d050cc9e11/src/train.py#L168

This results in cross-node gradient averaging for `discriminator` and `encoder`, yet `decoder` is not shared, so effectively a separate domain is trained per node. `DataParallel` is only used to accelerate batch processing across the GPUs within a node.

However, single-node training still rotates the domain from which batches are sampled: https://github.com/facebookresearch/music-translation/blob/fd51cbcbeb0af3de0930e79c25b539d050cc9e11/src/train.py#L273

In that case, single-node training does learn a domain-agnostic hidden representation, yet synthesis will just reproduce the very same input waveform.
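For reference, a simplified reading of the wrapping at the linked line (not a verbatim copy of `src/train.py`):

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# encoder and discriminator average gradients across nodes;
# decoder is only replicated over the local GPUs of one node.
encoder = DistributedDataParallel(encoder)
discriminator = DistributedDataParallel(discriminator)
decoder = nn.DataParallel(decoder)
```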
How is the domain supposed to be transferred in that case? There's no parameterized domain class inside the latent representation, and only a single decoder is being trained.
Is there a hidden purpose for single-node training? It's not clear how to perform music translation with it.
Regarding pretrained weights, `encoder` parameters differ across the per-node checkpoints. For some reason they were not synchronized during training: outputs produced by encoders loaded from `lastmodel_0.pth`, `lastmodel_1.pth`, and `lastmodel_2.pth` for a common input `numpy.random.randint(0, 256, (1, 800))` are different.
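The check was roughly the following (the checkpoint key name and the exact input shape/dtype expected by the encoder are assumptions here):

```python
import numpy as np
import torch

x = torch.from_numpy(np.random.randint(0, 256, (1, 800))).float()  # common random input

outputs = []
for path in ["lastmodel_0.pth", "lastmodel_1.pth", "lastmodel_2.pth"]:
    state = torch.load(path, map_location="cpu")
    encoder.load_state_dict(state["encoder_state"])  # key name assumed
    encoder.eval()
    with torch.no_grad():
        outputs.append(encoder(x))

# If the encoders had been synchronized, these should all be True.
print(torch.allclose(outputs[0], outputs[1]), torch.allclose(outputs[0], outputs[2]))
```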