magenta / magenta-js

Magenta.js: Music and Art Generation with Machine Learning in the browser
https://magenta.tensorflow.org
Apache License 2.0

Add compatibility for attention applied only at single layer of MusicRNN. #106

Open naotokui opened 6 years ago

naotokui commented 6 years ago

Hi,

I've trained a DrumsRNN model on my own drum sequence dataset and have been trying to use it with magenta-js. When I load the model, I get errors apparently caused by differences in the layer structure described in weights_manifest.json. I suspect it is a compatibility issue.

Which tensorflow version is compatible with magenta-js?

FYI, I used tensorflow 1.4.1 to train the DrumsRNN model.

Thanks

adarob commented 6 years ago

What is the error that you saw?

naotokui commented 6 years ago

I have a CodePen sketch based on Tero's amazing Neural Drum Machine.

Every time I tried to generate sequences (sampleRnn()) using my own model, I got this error:

magentamusic.js:45133 Uncaught (in promise) TypeError: Cannot read property 'matMul' of undefined
    at t.sampleRnn (magentamusic.js:45133)
    at magentamusic.js:45098
    at Object.t.tidy (magentamusic.js:12849)
    at t.<anonymous> (magentamusic.js:45088)
    at magentamusic.js:44957
    at Object.next (magentamusic.js:44938)
    at magentamusic.js:44932
    at new Promise (<anonymous>)
    at r (magentamusic.js:44928)
    at t.continueSequence (magentamusic.js:45069)

Then I found that my instance of the MusicRNN class has no LSTM cells after initialization (this.lstmCells.length === 0). I believe this is due to the differences in the network structure described in weights_manifest.json.
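One plausible way a renamed variable produces zero cells is a loader that looks for each cell under a single fixed prefix and stops at the first cell it cannot find. The sketch below is hypothetical and simplified (it is not magenta.js's actual loading code). `countCellsUnderPrefix` is an illustrative name, and the `cell_1` entry in each example list assumes a two-layer model beyond the `cell_0` names quoted in this issue.

```typescript
// Hypothetical, simplified loader logic (NOT the real magenta.js code):
// probe `${prefix}cell_${i}/basic_lstm_cell/bias` for i = 0, 1, 2, ...
// and stop at the first index whose variable is missing.
function countCellsUnderPrefix(names: string[], prefix: string): number {
  let count = 0;
  while (names.indexOf(`${prefix}cell_${count}/basic_lstm_cell/bias`) !== -1) {
    count++;
  }
  return count;
}

// Prefix matching the original DrumsRNN checkpoint layout.
const expectedPrefix = "rnn/attention_cell_wrapper/multi_rnn_cell/";

// Illustrative names: attention wraps the whole MultiRNNCell.
const originalNames = [
  "rnn/attention_cell_wrapper/multi_rnn_cell/cell_0/basic_lstm_cell/bias",
  "rnn/attention_cell_wrapper/multi_rnn_cell/cell_1/basic_lstm_cell/bias",
];

// Illustrative names: the attention wrapper sits inside cell_0 only.
const retrainedNames = [
  "rnn/multi_rnn_cell/cell_0/attention_cell_wrapper/basic_lstm_cell/bias",
  "rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias",
];

console.log(countCellsUnderPrefix(originalNames, expectedPrefix)); // 2
console.log(countCellsUnderPrefix(retrainedNames, expectedPrefix)); // 0
```

With no cells built, any later step that uses the first cell's weights would be reading from `undefined`, which would be consistent with the `matMul` error above.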

Here is the weights_manifest.json of the original DrumsRNN model. https://gist.github.com/naotokui/847394f037c3481c27600da13a51728a

and this is mine. https://gist.github.com/naotokui/aa64abf09e748c0d5dbc14d7009731d4
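To compare the two manifests side by side, it helps to pull out just the variable names. A minimal sketch, assuming the usual TensorFlow.js manifest shape (an array of groups, each with `paths` and a `weights` list of `{name, shape, dtype}` entries). The interfaces and `variableNames` are defined locally for illustration, not imported from any library.

```typescript
// Local types mirroring the TensorFlow.js weights-manifest format.
interface WeightEntry {
  name: string;
  shape: number[];
  dtype: string;
}

interface WeightGroup {
  paths: string[];
  weights: WeightEntry[];
}

// Collects every variable name listed in a weights_manifest.json document,
// in the order the groups and entries appear.
function variableNames(manifestJson: string): string[] {
  const groups = JSON.parse(manifestJson) as WeightGroup[];
  const result: string[] = [];
  for (const group of groups) {
    for (const weight of group.weights) {
      result.push(weight.name);
    }
  }
  return result;
}

// Illustrative manifest; the shard path and shape are made up.
const sampleManifest = JSON.stringify([
  {
    paths: ["group1-shard1of1"],
    weights: [
      {
        name: "rnn/multi_rnn_cell/cell_0/attention_cell_wrapper/basic_lstm_cell/bias",
        shape: [4],
        dtype: "float32",
      },
    ],
  },
]);

console.log(variableNames(sampleManifest));
```

Running this on both manifests and sorting the output makes prefix differences easy to spot.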

For example, one of the LSTM layers in the original model has a name like this: rnn/attention_cell_wrapper/multi_rnn_cell/cell_0/basic_lstm_cell/bias

but in my model, the same layer is named: rnn/multi_rnn_cell/cell_0/attention_cell_wrapper/basic_lstm_cell/bias

I suspect this is the cause of this issue.
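The two naming schemes can be told apart from their prefixes alone. Below is a hypothetical helper; its name and return labels are illustrative. It also assumes that a model trained without attention has plain `rnn/multi_rnn_cell/cell_N/basic_lstm_cell/` names, which this issue does not show directly.

```typescript
type AttentionLayout = "outer" | "per-cell" | "none" | "unknown";

// Hypothetical helper: infers where (if anywhere) the attention wrapper sits,
// based only on variable-name prefixes.
function detectAttentionLayout(names: string[]): AttentionLayout {
  // Original layout: attention_cell_wrapper wraps the whole multi_rnn_cell.
  if (names.some((n) => n.indexOf("rnn/attention_cell_wrapper/multi_rnn_cell/") === 0)) {
    return "outer";
  }
  // Retrained layout: attention_cell_wrapper appears inside an individual cell.
  if (names.some((n) => /^rnn\/multi_rnn_cell\/cell_\d+\/attention_cell_wrapper\//.test(n))) {
    return "per-cell";
  }
  // Assumption: a model trained without attention has plain LSTM cells only.
  if (names.some((n) => /^rnn\/multi_rnn_cell\/cell_\d+\/basic_lstm_cell\//.test(n))) {
    return "none";
  }
  return "unknown";
}

console.log(detectAttentionLayout([
  "rnn/attention_cell_wrapper/multi_rnn_cell/cell_0/basic_lstm_cell/bias",
])); // "outer"
console.log(detectAttentionLayout([
  "rnn/multi_rnn_cell/cell_0/attention_cell_wrapper/basic_lstm_cell/bias",
])); // "per-cell"
```

A loader could branch on this result instead of assuming one layout. Matching names is only part of the problem, though: the two layouts apply attention at different points in the network, so the forward pass would differ as well.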

adarob commented 6 years ago

Thanks for pointing this out. There is indeed a compatibility issue due to a change we made in the Python code. Can you share your checkpoint with me so that I can make the required changes to magenta.js to support both types of checkpoints?

In reply to Nao Tokui (Tue, Jul 17, 2018, 8:09 PM).

naotokui commented 6 years ago

Thank you! here are my checkpoints: https://www.dropbox.com/sh/84ohnr9ee8yvzvm/AADn70s7HGzcLhQkLAfEaRbna?dl=0

adarob commented 6 years ago

Hi Nao,

Unfortunately I won't be able to look into this until early next week. If you want something that works before then, it should be possible to retrain your model with "--hparams=attn_length=0". The resulting checkpoint should then work, although it may not perform quite as well without attention.

-Adam

In reply to Nao Tokui (Thu, Jul 19, 2018, 7:39 AM).

mjxmusic commented 2 years ago

hi @adarob, any idea if a fix was ever implemented for this in Magenta.js?

adarob commented 2 years ago

Sorry, but I am not sure if/when we will have time to fix this on our end. However, we would absolutely accept a PR that does!