magenta / mt3

MT3: Multi-Task Multitrack Music Transcription
Apache License 2.0

Bugfix in event decoder (?) #69

Open: sapieneptus opened this pull request 2 years ago

sapieneptus commented 2 years ago

Background

While running the decoder (run_length_encoding.decode_events) on a sequence of tokens (MIDI file -> NoteSequence -> event tokens), I realised that every time a note event is processed and decoding returns to shift events, cur_time is effectively reset: it is recomputed from start_time (0 in my case) plus only the shifts seen since that note. This then causes note_sequences.decode_note_event to raise an exception:

if time < state.current_time:
    raise ValueError('event time < current time, %f < %f' % (
        time, state.current_time))

I believe cur_time should be an offset from state.current_time rather than from start_time; this lets decoding pick up where it left off along the timeline.
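
To make the failure concrete, here is a minimal, self-contained sketch. It does not use MT3: the `('shift', steps)` / `('note', pitch)` tuple tokens, the assumed resolution `STEPS_PER_SECOND = 100`, and the `DecodingState`, `decode_note_event` and `decode_events` functions are simplified, hypothetical stand-ins. The loop assumes, as described above, that the step counter restarts after every note and that times are measured from `start_time`; the `relative_to_state` flag switches to the proposed offset from `state.current_time`.

```python
from typing import List, Tuple

Token = Tuple[str, int]  # hypothetical token: ('shift', steps) or ('note', pitch)
STEPS_PER_SECOND = 100  # assumed codec time resolution


class DecodingState:
    """Simplified stand-in for a note decoding state."""

    def __init__(self) -> None:
        self.current_time = 0.0
        self.notes: List[Tuple[float, int]] = []


def decode_note_event(state: DecodingState, time: float, pitch: int) -> None:
    # Same guard as the check quoted above: time must never move backwards.
    if time < state.current_time:
        raise ValueError('event time < current time, %f < %f' % (time, state.current_time))
    state.current_time = time
    state.notes.append((time, pitch))


def decode_events(state: DecodingState, tokens: List[Token], start_time: float,
                  relative_to_state: bool = False) -> None:
    cur_steps = 0
    cur_time = start_time
    for kind, value in tokens:
        if kind == 'shift':
            cur_steps += value
            # Current behaviour: offset from start_time.
            # Proposed behaviour: offset from the last decoded event time.
            base = state.current_time if relative_to_state else start_time
            cur_time = base + cur_steps / STEPS_PER_SECOND
        else:
            cur_steps = 0  # the step counter restarts after every non-shift token
            decode_note_event(state, cur_time, value)


# Two notes at 0.03 s and 0.05 s, with shifts counted from the previous note.
tokens = [('shift', 3), ('note', 60), ('shift', 2), ('note', 62)]

fixed = DecodingState()
decode_events(fixed, tokens, start_time=0.0, relative_to_state=True)

try:
    decode_events(DecodingState(), tokens, start_time=0.0)
    current_behaviour_raised = False
except ValueError:
    current_behaviour_raised = True
```

With the current behaviour the second note decodes to 0.02 s and trips the guard; with the proposed offset it lands at 0.05 s.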


Code example:

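import note_seq

import midi   # the author's midi.py, shown below
import utils  # assumed: the author's own module defining CODEC (not shown)

# Not defined in the original snippet:
#   ns          - the full NoteSequence, e.g. from midi.midi_file_to_note_sequence(...)
#   audio_times - the frame times passed on to encode_midi_events

# Split the full sequence into 1-second subsequences and encode each one.
subsequences = note_seq.split_note_sequence(ns, 1)
event_batches = []
for subseq in subsequences:
    subseq = note_seq.apply_sustain_control_changes(subseq)
    midi_times, midi_events = midi.note_sequence_to_events(subseq)
    del subseq.control_changes[:]

    events, _, _, _, _ = midi.encode_midi_events(audio_times, midi_times, midi_events)
    event_batches.append(events)

# Decode all batches back into one NoteSequence and write it to disk.
reconstructed = midi.event_batches_to_note_sequence(event_batches, codec=utils.CODEC)

midi.note_sequence_to_midi_file(reconstructed, 'moo.mid')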

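# midi.py

from typing import Sequence, Tuple

import note_seq
from mt3 import event_codec
from mt3 import note_sequences
from mt3 import run_length_encoding

import utils  # assumed: the author's own module defining CODEC (not shown)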

def midi_file_to_note_sequence(midi_path) -> note_seq.NoteSequence:
    """
    Load a MIDI file into a NoteSequence.
    """
    print(f"Converting midi file to note sequence: {midi_path}")
    ns = note_seq.midi_file_to_note_sequence(midi_path)
    return ns

def note_sequence_to_events(ns: note_seq.NoteSequence) -> Tuple[Sequence[float], Sequence[note_sequences.NoteEventData]]:
    return note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns)

def event_batches_to_note_sequence(event_batches, codec: event_codec.Codec=utils.CODEC) -> note_seq.NoteSequence:
    print("converting event batches to note sequence")
    decoding_state = note_sequences.NoteDecodingState()
    total_invalid_ids = 0
    total_dropped_events = 0

    for events in event_batches:
        invalid_ids, dropped_events = run_length_encoding.decode_events(
            state=decoding_state,
            tokens=events,
            start_time=decoding_state.current_time,
            max_time=None,
            codec=codec,
            decode_event_fn=note_sequences.decode_note_event
        )
        total_invalid_ids += invalid_ids
        total_dropped_events += dropped_events

    ns = note_sequences.flush_note_decoding_state(decoding_state)

    print(f'Dropped {total_dropped_events} events')
    print(f'Invalid ids: {total_invalid_ids}')
    return ns

def note_sequence_to_midi_file(ns: note_seq.NoteSequence, midi_path: str):
    """
    Write a NoteSequence to a MIDI file.
    """
    print(f"Writing note sequence to MIDI file: {midi_path}")

    return note_seq.midi_io.note_sequence_to_midi_file(ns, midi_path)

def encode_midi_events(
    audio_frame_times: Sequence[float],
    midi_event_times: Sequence[float],
    midi_event_values: Sequence[note_sequences.NoteEventData]
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:

    events, event_start_indices, event_end_indices, state_events, state_event_indices = run_length_encoding.encode_and_index_events(
        state=note_sequences.NoteEncodingState(),
        event_times=midi_event_times,
        event_values=midi_event_values,
        encode_event_fn=note_sequences.note_event_data_to_events,
        codec=utils.CODEC,
        frame_times=audio_frame_times,
        encoding_state_to_events_fn=note_sequences.note_encoding_state_to_events
    )
    return events, event_start_indices, event_end_indices, state_events, state_event_indices
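
The shape of the token stream explains the symptom. Assuming, as a simplification, that the encoder catches up to each event by emitting one single-step shift per elapsed step (a simplified reading of how `encode_and_index_events` is used above, not a copy of it), the shifts between two notes count only the steps since the previous note. The sketch uses hypothetical tuple tokens, an assumed 100 steps per second, and hypothetical helpers `encode_events` and `shifts_before_each_note`.

```python
from typing import List, Tuple

Token = Tuple[str, int]  # hypothetical token: ('shift', 1) or ('note', pitch)
STEPS_PER_SECOND = 100  # assumed codec time resolution


def encode_events(event_times: List[float], pitches: List[int]) -> List[Token]:
    """Simplified encoder: unit shifts up to each event's step, then the event."""
    order = sorted(range(len(event_times)), key=lambda i: event_times[i])
    tokens: List[Token] = []
    cur_step = 0
    for i in order:
        event_step = round(event_times[i] * STEPS_PER_SECOND)
        while event_step > cur_step:  # catch up with single-step shifts
            tokens.append(('shift', 1))
            cur_step += 1
        tokens.append(('note', pitches[i]))
    return tokens


def shifts_before_each_note(tokens: List[Token]) -> List[int]:
    """Count the shift steps preceding each note, since the previous note."""
    counts: List[int] = []
    pending = 0
    for kind, value in tokens:
        if kind == 'shift':
            pending += value
        else:
            counts.append(pending)
            pending = 0
    return counts


tokens = encode_events([0.03, 0.05], [60, 62])
gaps = shifts_before_each_note(tokens)
```

Here `gaps` is `[3, 2]`: a decoder that measures each run of shifts from `start_time` and restarts after every note would place the second note at step 2, before the first one.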

google-cla[bot] commented 2 years ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

sapieneptus commented 2 years ago

I've signed the CLA. Can that build be rerun?

cghawthorne commented 2 years ago

Can you provide a little more information on how you're using the decoder? I'm guessing it's for some kind of custom setup.

The way we're using it (as illustrated in metrics_utils.event_predictions_to_ns), the ground truth for the current time offset comes from the start_time passed into decode_events. That's then used to set state.current_time. We do this because we're decoding independently-inferred chunks of the full sequence.

sapieneptus commented 2 years ago

> We do this because we're decoding independently-inferred chunks of the full sequence.

Right, I suspected as much. I was just trying to understand how the MT3 and note-seq libraries work, and wrote some code to split a MIDI file into subsequences to see whether I could reconstruct the original MIDI file from them. So my input is a sequence of events covering the entire MIDI file (several minutes' worth of events).

I can see how this function works as-is for a small slice containing a single note event plus some shift events, but it fails when it encounters multiple note events, unless the number of shifts before each successive note keeps increasing (i.e. the shifts count from the start of the slice rather than from the previous note).
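
One way to read "increasingly more shifts": if each run of shifts carries the total number of steps since the start of the slice, rather than since the previous note, a decoder that restarts its step counter after every note still produces non-decreasing times. The hypothetical helper `to_cumulative_shifts` below rewrites a per-gap stream into that form. The tuple tokens, `STEPS_PER_SECOND`, and `decode_times` are simplified stand-ins, not MT3 code.

```python
from typing import List, Tuple

Token = Tuple[str, int]  # hypothetical token: ('shift', steps) or ('note', pitch)
STEPS_PER_SECOND = 100  # assumed codec time resolution


def to_cumulative_shifts(tokens: List[Token]) -> List[Token]:
    """Replace each run of per-gap shifts with one shift holding the total since slice start.

    Shifts after the last note are dropped, since no event follows them.
    """
    out: List[Token] = []
    total = 0
    pending = False
    for kind, value in tokens:
        if kind == 'shift':
            total += value
            pending = True
        else:
            if pending:
                out.append(('shift', total))
                pending = False
            out.append((kind, value))
    return out


def decode_times(tokens: List[Token], start_time: float) -> List[Tuple[float, int]]:
    """Simplified decoder: times measured from start_time, counter reset after each note."""
    cur_steps = 0
    cur_time = start_time
    notes: List[Tuple[float, int]] = []
    for kind, value in tokens:
        if kind == 'shift':
            cur_steps += value
            cur_time = start_time + cur_steps / STEPS_PER_SECOND
        else:
            cur_steps = 0
            notes.append((cur_time, value))
    return notes


tokens = [('shift', 3), ('note', 60), ('shift', 2), ('note', 62)]
cumulative = to_cumulative_shifts(tokens)
```

Decoding `tokens` directly puts the second note at 0.02 s; decoding `cumulative` puts the notes at 0.03 s and 0.05 s.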

So perhaps this is by design, but I believe the change is still an improvement: it should not change behaviour for the 'small slice' use case and will prevent errors for longer slices.

I have updated the description with the relevant code.

cghawthorne commented 2 years ago

I think it still doesn't work for our case, because state.current_time at the end of one chunk isn't necessarily the right start time for the next chunk. For example, what if there are several chunks in a row with no note events? The start time of each chunk needs to come from an external source.
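
The empty-chunk case can be illustrated with another simplified sketch (hypothetical tuple tokens, an assumed 100 steps per second, chunks given as hypothetical `(chunk_start, tokens)` pairs; not MT3's API). Decoding each chunk from its externally known start time places a note in the fourth chunk at 3.25 s, while chaining each chunk's start from `state.current_time`, as `event_batches_to_note_sequence` in the description does, places it at 0.75 s because the two empty chunks never advance the state.

```python
from typing import List, Tuple

Token = Tuple[str, int]  # hypothetical token: ('shift', steps) or ('note', pitch)
STEPS_PER_SECOND = 100  # assumed codec time resolution


class DecodingState:
    """Simplified stand-in: the time of the last decoded note plus all notes."""

    def __init__(self) -> None:
        self.current_time = 0.0
        self.notes: List[Tuple[float, int]] = []


def decode_chunk(state: DecodingState, tokens: List[Token], start_time: float) -> None:
    # Shifts count steps from start_time; the counter restarts after each note.
    cur_steps = 0
    cur_time = start_time
    for kind, value in tokens:
        if kind == 'shift':
            cur_steps += value
            cur_time = start_time + cur_steps / STEPS_PER_SECOND
        else:
            cur_steps = 0
            if cur_time < state.current_time:
                raise ValueError(f'event time < current time, {cur_time} < {state.current_time}')
            state.current_time = cur_time
            state.notes.append((cur_time, value))


def decode_all(chunks: List[Tuple[float, List[Token]]], chain_from_state: bool) -> List[float]:
    state = DecodingState()
    for chunk_start, tokens in chunks:
        # External start time vs. wherever the previous chunk's decoding ended.
        start = state.current_time if chain_from_state else chunk_start
        decode_chunk(state, tokens, start)
    return [time for time, _ in state.notes]


# 1-second chunks; the middle two contain no note events.
chunks = [
    (0.0, [('shift', 50), ('note', 60)]),
    (1.0, []),
    (2.0, []),
    (3.0, [('shift', 25), ('note', 62)]),
]
external = decode_all(chunks, chain_from_state=False)
chained = decode_all(chunks, chain_from_state=True)
```

`external` is `[0.5, 3.25]` and `chained` is `[0.5, 0.75]`: nothing in the token stream records that two silent seconds passed, so the chunk offsets have to be supplied from outside.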