magenta / magenta-js

Magenta.js: Music and Art Generation with Machine Learning in the browser
Apache License 2.0

Generating MIDI in magenta-js with my own trained DrumsRNN model #599

Status: Open (opened by jonasdoevenspeck 2 years ago)

jonasdoevenspeck commented 2 years ago


I want to train my own DrumsRNN model and use it to generate MIDI in a web browser with magenta-js. MIDI generation and synthesis already work with the pretrained DrumsRNN model, but I run into errors when I use my own trained model.

1) First, I trained the model as follows:

drums_rnn_train --config="drum_kit" --run_dir="path1" --sequence_example_file="path2" --hparams="batch_size=64,rnn_layer_sizes=[64,64]" --num_training_steps=3000

The model is then used in magenta-js as follows:

  import { Player, MusicRNN } from '@magenta/music';

  const model = new MusicRNN("local_http_link");
  const player = new Player();
  await model.initialize();
  await player.initialize();

  // Manually build a quantized priming sequence: kick drum (pitch 36) on steps 0, 4 and 8
  const sequence = {
    notes: [
      {pitch: 36, quantizedStartStep: 0, quantizedEndStep: 1, isDrum: true},
      {pitch: 36, quantizedStartStep: 4, quantizedEndStep: 5, isDrum: true},
      {pitch: 36, quantizedStartStep: 8, quantizedEndStep: 9, isDrum: true},
    ],
    quantizationInfo: {stepsPerQuarter: 4},
    tempos: [{time: 0, qpm: 120}],
    totalQuantizedSteps: 9,
  };

  // Generate a 50-step continuation of the priming sequence and play it
  const samples = await model.continueSequence(sequence, 50);
  await player.start(samples);

This results in the following error, which has already been described in another issue:

Unhandled Rejection (TypeError): Cannot read properties of undefined (reading 'matMul')

2) As suggested in the issue linked above, I fixed this by adding "attn_length=0" to --hparams.

So the new training command is:

drums_rnn_train --config="drum_kit" --run_dir="path1" --sequence_example_file="path2" --hparams="batch_size=64,rnn_layer_sizes=[64,64],attn_length=0" --num_training_steps=3000

I used the same magenta-js code to generate new MIDI and now get the following error:

Unhandled Rejection (Error): Error in matMul: inner shapes (74) and (582) of Tensors with shapes 1,74 and 582,256 and transposeA=false and transposeB=false must match.

I have added the full error trace at the bottom of this issue. I suspect that the input dimension of my priming sequence is incompatible with the input dimension of the trained network, but I have no idea how to fix it.
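The two shapes in the error message can be related to per-step input sizes with a bit of arithmetic. The sketch below assumes an LSTM kernel of shape [inputDepth + units, 4 * units] that is multiplied by the concatenation of the step input and the hidden state h, and assumes that the drum_kit config encodes each step as a one-hot over all 2^9 combinations of 9 drum classes plus 6 binary-counter bits, with no lookback features. The function names are hypothetical, not part of magenta or magenta-js.

```javascript
// Sketch: relate the matMul shapes from the error to per-step input depths.
// Assumption: the LSTM kernel has shape [inputDepth + units, 4 * units]
// (one block of columns per gate) and is multiplied by concat([input, h]).
function lstmShapesFromKernel([rows, cols]) {
  const units = cols / 4;            // 256 / 4 = 64 hidden units
  return {units, inputDepth: rows - units};
}

// Assumption: drum_kit encodes each step as a one-hot over all 2^9
// combinations of 9 drum classes, plus 6 binary-counter bits (no lookbacks).
function drumKitInputDepth(numDrumClasses = 9, binaryCounterBits = 6) {
  return 2 ** numDrumClasses + binaryCounterBits;
}

const kernel = lstmShapesFromKernel([582, 256]);
const combinedWidthFromJs = 74;      // width of the [1, 74] left operand
const jsInputDepth = combinedWidthFromJs - kernel.units;

console.log(kernel);                 // { units: 64, inputDepth: 518 }
console.log(drumKitInputDepth());    // 518
console.log(jsInputDepth);           // 10
```

Under these assumptions, the first layer has 64 units (matching rnn_layer_sizes=[64,64]) and expects 518 input values per step, while the left operand built in the browser holds only 10 input values next to the 64-value hidden state. That would be consistent with magenta-js encoding the priming sequence differently from how the training data was encoded.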

Extra information:

1) I can generate new MIDI sequences with this model from the CLI using the following command:

drums_rnn_generate --config="drum_kit" --run_dir="model_path" --hparams="batch_size=64,rnn_layer_sizes=[64,64],attn_length=0" --output_dir="path_2"

However, unlike model.continueSequence() in magenta-js, this command does not require a priming sequence.

2) When I generate MIDI in magenta-js with the pretrained DrumsRNN model, I get no errors.
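Because model.continueSequence() always needs a quantized priming sequence, it can be convenient to build one from a list of (pitch, step) hits instead of writing each note by hand. A minimal sketch, assuming the same plain-object fields used in the priming sequence of the original snippet; makeDrumPrimer is a hypothetical helper, not part of @magenta/music.

```javascript
// Hypothetical helper (not part of @magenta/music): builds a quantized,
// drum-only priming sequence as a plain NoteSequence-shaped object.
function makeDrumPrimer(hits, {stepsPerQuarter = 4, qpm = 120} = {}) {
  // Each hit is [pitch, step]; every note lasts exactly one step.
  const notes = hits.map(([pitch, step]) => ({
    pitch,
    quantizedStartStep: step,
    quantizedEndStep: step + 1,
    isDrum: true,
  }));
  // The sequence ends at the latest quantized end step (0 if there are no hits).
  const totalQuantizedSteps = notes.reduce(
      (max, n) => Math.max(max, n.quantizedEndStep), 0);
  return {
    notes,
    quantizationInfo: {stepsPerQuarter},
    tempos: [{time: 0, qpm}],
    totalQuantizedSteps,
  };
}

// Same primer as in the issue: kick drum (pitch 36) on steps 0, 4 and 8.
const primer = makeDrumPrimer([[36, 0], [36, 4], [36, 8]]);
console.log(primer.totalQuantizedSteps); // 9
```

The resulting object has the same shape as the hand-written sequence, so it could be passed to model.continueSequence(primer, 50) in its place. It does not change how the model encodes the primer, so by itself it would not resolve the shape mismatch.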

Full error trace:

  Unhandled Rejection (Error): Error in matMul: inner shapes (74) and (582) of Tensors with shapes 1,74 and 582,256 and transposeA=false and transposeB=false must match.
    Module.assert                     src/util_base.ts:108
    batchMatMul [as kernelFunc]       src/kernels/BatchMatMul.ts:64
    kernelFunc                        src/engine.ts:598
    (anonymous function)              src/engine.ts:668
    Engine.scopedRun                  src/engine.ts:453
    Engine.runKernelFunc              src/engine.ts:665
    Engine.runKernel                  src/engine.ts:522
    matMul_                           src/ops/mat_mul.ts:54
    matMul__op                        src/ops/operation.ts:51
    basicLSTMCell_                    src/ops/basic_lstmcell.ts:61
    Module.basicLSTMCell__op          src/ops/operation.ts:51
    Array.<anonymous>                 src/music_rnn/model.ts:165
    multiRNNCell_                     src/ops/multi_rnn_cell.ts:56
    Module.multiRNNCell__op           src/ops/operation.ts:51
    MusicRNN.sampleRnn                src/music_rnn/model.ts:375
    (anonymous function)              src/music_rnn/model.ts:272
    (anonymous function)              src/engine.ts:442
    Engine.scopedRun                  src/engine.ts:453
    Engine.tidy                       src/engine.ts:440
    Module.tidy                       src/globals.ts:192
    MusicRNN.continueSequenceImpl     src/music_rnn/model.ts:258
    MusicRNN.continueSequence         src/music_rnn/model.ts:215
    playGen                           src/magenta/magenta.js:137

mjxmusic commented 2 years ago

hi @jonasdoevenspeck - did you ever work through this issue?