mayhemsloth / Drum-Tabber

Automatic drum tab project using self-created data set and TensorFlow NN

Implement and verify weighted cross entropy loss in compute loss function #21

Closed mayhemsloth closed 3 years ago

mayhemsloth commented 3 years ago

Due to the sparsity of positive labels in every class relative to the number of spectrogram slices being produced, a weighted cross entropy loss should probably be used to correct for the imbalance. Another way to rebalance the dataset would be to duplicate the positive examples until they are as numerous as the negative ones, but for various reasons I don't want to do this. The main reason is that once I implement the recurrent CNN layers, I will not want to be randomly choosing different sections and copying them into the dataset; I want to keep everything in time order.
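For context, TensorFlow ships exactly this loss as `tf.nn.weighted_cross_entropy_with_logits`, which scales only the positive-label term by a per-class `pos_weight`. A minimal sketch of how the compute loss function could wrap it (the name `weighted_compute_loss` and the tensor shapes are assumptions for illustration, not the repo's actual code):

```python
import tensorflow as tf

def weighted_compute_loss(labels, logits, pos_weight):
    """Per-class weighted binary cross entropy (a sketch).

    labels:     (batch, n_classes) 0/1 drum-hit targets
    logits:     (batch, n_classes) raw model outputs (pre-sigmoid)
    pos_weight: (n_classes,) multiplier on the positive term,
                e.g. 1/frequency of each class
    """
    # pos_weight broadcasts across the batch dimension, so each drum
    # class gets its own penalty for missed positive labels
    per_element = tf.nn.weighted_cross_entropy_with_logits(
        labels=labels, logits=logits, pos_weight=pos_weight)
    return tf.reduce_mean(per_element)
```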

To do this, we need the total number of positive labels for each class in the entire dataset (including the validation set) before computing any losses. This information will be used to produce the weights in the weighted cross entropy loss, as in the sketch below.
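A sketch of gathering those counts, assuming the full dataset's labels can be stacked into a single `(n_slices, n_classes)` 0/1 array; the helper name `class_pos_weights` is hypothetical:

```python
import numpy as np

def class_pos_weights(label_matrix):
    """label_matrix: (n_slices, n_classes) 0/1 array covering the entire
    dataset, validation set included.

    Returns the 1/frequency weight per class, where frequency is the
    fraction of slices containing a positive label for that class."""
    freq = label_matrix.mean(axis=0)      # positives per class / total slices
    return 1.0 / np.maximum(freq, 1e-8)   # guard against classes with no hits
```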

This issue is resolved when code has been written and verified to work properly with weighted cross entropy loss in the compute loss function.

mayhemsloth commented 3 years ago

After computing the frequency of every class (with FullSet only; as with most of the rest of the code base, this does not yet work without access to the FullSet), I passed the 1/frequency list in as the weights of the weighted cross entropy loss, and the model learned to predict 1s in the classes. The "most obvious" classes, bass drum and snare drum events, were learnable by a simple model on my current dataset, so I think I can verify that the weighted loss is doing its job. Before this implementation, the model loss would drift to ln(2) ≈ 0.693 (the binary cross entropy of predicting 0.5 for every label) and the model would never guess 1s.
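As a sanity check on the ln(2) observation: a model whose sigmoid outputs sit at 0.5 (i.e., logits of 0) has an unweighted binary cross entropy of exactly ln(2) regardless of the labels, which matches the plateau described above. A minimal demonstration (the label values are arbitrary, chosen only to show the result is label-independent):

```python
import numpy as np
import tensorflow as tf

labels = tf.constant([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # arbitrary sparse labels
logits = tf.zeros_like(labels)                            # sigmoid(0) = 0.5 everywhere

bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
print(float(tf.reduce_mean(bce)), np.log(2))  # both ~0.6931
```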