sachinravi14 closed this issue 6 years ago
@sachinravi14 How are you running the train script while getting the import paths to work? I can do it by manually setting PYTHONPATH, but it might be good to have a documented way to run train.py.
@sagelywizard I ran it with:
python -um train.train --data=config/lyrics.yaml --model=config/unigram.yaml --task=config/5shot.yaml --checkpt_dir=[insert path]
Does this work for you? I guess we should add this to the documentation in the README.
Currently, the code will crash when you do multiple runs in a row. It'll crash on https://github.com/AI-ON/Few-Shot-Music-Generation/pull/12/files#diff-5b8530a56f7576949b498f530b79df0cR105 because the samples directory already exists. It'd be nice if it automatically saved samples from different runs in different directories. Thoughts?
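One way to avoid the crash described above (a sketch only; the function and directory names here are assumptions, not the repo's actual code) is to create the samples directory with exist_ok=True so a rerun doesn't raise FileExistsError:

```python
import os

def make_sample_dir(checkpoint_dir):
    """Create a samples directory, tolerating repeated runs.

    Sketch of one possible fix: os.makedirs with exist_ok=True
    avoids the crash when checkpoint_dir/samples already exists
    from a previous run.
    """
    sample_dir = os.path.join(checkpoint_dir, "samples")
    os.makedirs(sample_dir, exist_ok=True)  # no FileExistsError on reruns
    return sample_dir
```

Note that this silently reuses the directory, so samples from an earlier run could still be overwritten; separate per-run directories would avoid that too.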
For multiple runs, we would also need to checkpoint separate model snapshots, right? Currently, on a new run, wouldn't it just overwrite the previous run's model snapshot files?
That's true. The problem I was referring to was multiple runs of different models. So, for example, if you run the unigram model, it'll save the checkpoints in checkpoint_dir/unigram_model and the samples in checkpoint_dir/samples. If I run the same thing again with the LSTM model, it'll save it in checkpoint_dir/basic_tf_model. When it tries to save the samples, it'll break since checkpoint_dir/samples already exists.
So, the issue is that if you don't change the checkpoint directory between runs, it'll crash at the end and you'll lose the samples. It'd be nice to change the semantics somehow so that it doesn't crash at the end. Maybe it could throw an error at the beginning if the checkpoint directory already exists? Or treat the checkpoint directory as a directory name prefix, e.g. checkpoint_dir_0/samples, checkpoint_dir_1/samples, etc. Or something else along those lines?
For what it's worth, these are the numbers I was getting:
Validation Avg NLL: 6.106e+00
Test Avg NLL: 6.053e+00
Validation Avg NLL: 6.128e+00
Test Avg NLL: 6.087e+00
Validation Avg NLL: 7.008e+00
Test Avg NLL: 6.943e+00
Validation Avg NLL: 6.911e+00
Test Avg NLL: 6.865e+00
@sachinravi14 It would be good if you added documentation in the train folder on how to reproduce your results. Also, would you please add a requirements.txt file to install the needed dependencies?
@sagelywizard For multiple runs, I imagined that for run i of the unigram model, for example, the argument would be checkpt_dir=checkpt_dir/unigram_model_run_i/, meaning that each run of the model would have its own directory for its snapshots and samples. What do you think about that?
@sachinravi14 Sounds reasonable to me, but it'd probably be good to add a note to the README that you shouldn't do multiple runs with the same checkpt_dir.
@sachinravi14 I had the same issue with the checkpoint dirs, and I realized it a bit late, when the runs were already corrupt, so I had to reconfigure and relaunch. I agree with @sagelywizard: it might be good either to change the name automatically (for example, append _MODEL_ID to the checkpt_dir) or to add a note in the README file.
The MIDI test samples do not have enough data to be played (and all contain the same information), and the lyrics test samples all contain the same information as each other. All of the generated MIDI samples I got are 41 bytes long; likewise, all of the text samples are 99 bytes long.
Lyrics data in the model samples: "I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I"
MIDI data in the model samples (HEX/TXT view):
00000000: 4d54 6864 0000 0006 0001 0001 00dc 4d54  MThd..........MT
00000010: 726b 0000 0013 00ff 5103 07a1 2000 ff58  rk......Q... ..X
00000020: 0404 0218 0801 ff2f 000a                 ......./..
I did a simple test, increasing the sample size (train.py line 109) to 10 times the max length (500 elements), and the data obtained is the following (with the unigram model only), which seems to just repeat the first element:
"I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I"
Here is a listing (note: the _midi folders are from unigram checkpoints, the _midi_2 folders are from LSTM generation):
-rw-rw-r-- 1 user user 41 mai 19 11:38 ./checkpoints_midi_2/samples/sample_0/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_10/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_11/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_12/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_13/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_14/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_15/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_16/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_17/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_18/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_19/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:38 ./checkpoints_midi_2/samples/sample_1/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:38 ./checkpoints_midi_2/samples/sample_2/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:38 ./checkpoints_midi_2/samples/sample_3/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_4/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_5/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_6/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_7/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_8/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 19 11:39 ./checkpoints_midi_2/samples/sample_9/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_0/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_10/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_11/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_12/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_13/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_14/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_15/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_16/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_17/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_18/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_19/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_1/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_2/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_3/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_4/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_5/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_6/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_7/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_8/model_sample.mid
-rw-rw-r-- 1 user user 41 mai 18 20:08 ./checkpoints_midi/samples/sample_9/model_sample.mid
@leomrocha It is expected for the unigram model that the samples all contain the same token, because the sample function for the unigram model currently just picks the token with the highest probability over and over; for example, I is the most common word in the lyrics data, so it is constantly picked. The real question is why this also occurs for the LSTM model, since it can model longer context. But based on the performance, the LSTM is performing similarly to the unigram model, so it doesn't seem to be doing more than modeling word counts. The main question is whether the LSTM's bad performance is due to some bug, or whether it is simply hard to learn to model longer context when the songs keep changing for each artist.
@sachinravi14 Thx for the explanation.
These commits add a training script, an abstract model API, and TensorFlow implementations of LSTM and unigram baselines based on that API.
The training script trains a given model for a set number of iterations, evaluates it on validation and test sets, and saves a set number of sampled songs from the trained model.
Currently, the performance of the models (measured by average negative log-likelihood) on sampled episodes from the test sets of the lyrics and MIDI datasets is the following:
Lyric:
Midi:
It is suspicious that the LSTM model is doing no better than (and in some cases worse than) the unigram model; I'm looking into it.