Official implementation of the paper "Acoustic Music Understanding Model with Large-Scale Self-supervised Training".
High VRAM usage for long audio #6

Open · hykilpikonna opened this issue 1 year ago

hykilpikonna commented 1 year ago

The model seems to process the entire audio at once, which leads to high VRAM usage for long audio. I was trying to compute MERT embeddings for a 9:58-long clip on an A100 80GB GPU, and it tried to allocate 90 GB of VRAM.
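A rough estimate suggests why the allocation is so large. MERT's encoder uses full self-attention, so the attention-score tensor grows with the square of the number of frames. The sketch below is a back-of-the-envelope calculation under stated assumptions: the 75 Hz frame rate mentioned in the code comments of this issue, 12 attention heads (the MERT-v1-95M configuration, assumed here), batch size 1, and an fp32 attention implementation that materializes one layer's full `[batch, heads, T, T]` score tensor. Memory-efficient attention kernels or half precision would change the numbers; `attention_scores_gib` is a hypothetical helper.

```python
def attention_scores_gib(duration_s, frame_rate_hz=75, num_heads=12,
                         bytes_per_value=4, batch_size=1):
    """Estimate the size of ONE layer's attention-score tensor.

    Assumes the tensor has shape [batch, heads, T, T] and is stored in full,
    with T = duration * frame rate. Returns (T, size in GiB).
    """
    frames = int(duration_s * frame_rate_hz)
    n_bytes = batch_size * num_heads * frames * frames * bytes_per_value
    return frames, n_bytes / 2**30


for label, seconds in [("60 s window", 60), ("9:58 clip", 9 * 60 + 58)]:
    frames, gib = attention_scores_gib(seconds)
    print(f"{label}: {frames} frames -> {gib:.1f} GiB per score tensor")
```

Under these assumptions, a single score tensor for the 9:58 clip (44,850 frames) comes to about 89.9 GiB, which would be consistent with the reported 90 GB allocation. A 60 s window (4,500 frames) needs under 1 GiB. This quadratic growth is why windowing is attractive in the first place.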


Is it possible to split the audio first, process each segment, and obtain the same results? I tried splitting the audio into 60 s windows using the code below. Although I managed to make the segmented embeddings the same shape as the original, they show a large mean squared error compared with the embeddings computed by passing the entire audio in at once.

```python
import torch

def embed(self, audio):  # hypothetical method name; self provides sr, processor, model, device
    window_length = int(self.sr * 60)      # 60-second windows, in samples
    overlap_length = int(self.sr * 4.987)  # 4.987 s overlap (5 s minus one 75 Hz frame), in samples
    overlap_frames = int(4.987 * 75) - 1   # overlap in output frames (75 Hz frame rate)
    embeddings = []

    print("Audio shape:", audio.shape)
    print("Window length:", window_length)
    print("Overlap length:", overlap_length)
    print("Overlap frames:", overlap_frames)

    # Iterate over the audio with overlapping windows
    for start in range(0, audio.shape[0], window_length - overlap_length):
        end = start + window_length
        segment = audio[start:end]
        print("Segment:", segment.shape)
        # if len(segment) < window_length:
        #     break

        # Process each segment
        inputs = self.processor(segment, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
            out = torch.stack(out.hidden_states).squeeze()  # [13 layers, timeframes, 768]
            out = out[11]  # [timeframes, 768]

        print("Frames before:", out.shape[0])

        # Remove the overlap from the end of every segment except the last
        if end < audio.shape[0]:
            out = out[:-overlap_frames, :]

        print("Frames after:", out.shape[0])
        embeddings.append(out)  # collect this segment's frames for concatenation below

    # Stack embeddings for all segments
    out = torch.cat(embeddings, dim=0)

    return out
```
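The loop above trims overlap only from the end of each segment, so every segment after the first begins with no left context. A common alternative for windowed inference is to give each kept region context on both sides and discard that context afterwards. The sketch below illustrates this under assumptions and is not MERT's own method. `plan_chunks` and `stitch` are hypothetical helpers that work in frame units. `embed_fn` stands in for the processor-plus-model call and must return one embedding per frame. `hop` is the number of samples per output frame, e.g. 320 if the audio is at 24 kHz and the model emits 75 frames per second. Because self-attention is global, this should reduce the mismatch rather than remove it; more context should get closer to the full-pass result, at the cost of memory.

```python
def plan_chunks(total_frames, chunk_frames, context_frames):
    """Return (win_start, win_end, keep_start, keep_end) tuples in frame units.

    The model sees frames [win_start, win_end); only [keep_start, keep_end) is
    kept, so each kept frame has up to `context_frames` of context on both
    sides (less at the very start and end of the clip).
    """
    if chunk_frames <= 0 or context_frames < 0:
        raise ValueError("need chunk_frames > 0 and context_frames >= 0")
    plan = []
    for keep_start in range(0, total_frames, chunk_frames):
        keep_end = min(keep_start + chunk_frames, total_frames)
        win_start = max(0, keep_start - context_frames)
        win_end = min(total_frames, keep_end + context_frames)
        plan.append((win_start, win_end, keep_start, keep_end))
    return plan


def stitch(audio, embed_fn, hop, chunk_frames, context_frames):
    """Run `embed_fn` window by window and concatenate only the kept frames.

    Assumes frame i covers samples [i * hop, (i + 1) * hop). A real model may
    emit one frame fewer per window (convolution edges), so check lengths.
    """
    total_frames = len(audio) // hop
    kept = []
    for win_start, win_end, keep_start, keep_end in plan_chunks(
            total_frames, chunk_frames, context_frames):
        frames = embed_fn(audio[win_start * hop:win_end * hop])
        # Offsets of the kept region inside this window's output
        kept.extend(frames[keep_start - win_start:keep_end - win_start])
    return kept


# Toy, strictly local "model" (mean of each hop) to check the bookkeeping
def frame_means(samples, hop=10):
    return [sum(samples[i * hop:(i + 1) * hop]) / hop
            for i in range(len(samples) // hop)]


audio = list(range(1000))
full = frame_means(audio)
chunked = stitch(audio, frame_means, hop=10, chunk_frames=30, context_frames=5)
print(len(full), len(chunked), full == chunked)
```

With a purely local toy model, the stitched output matches the full pass exactly, which confirms the index arithmetic. With MERT, `embed_fn` could wrap the processor and model call from the snippet above and return layer 11 for the window. How much context is enough for a given error tolerance would need to be measured, for example by repeating the error plot for several values of `context_frames`.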
hykilpikonna commented 1 year ago

Here is a graph of the absolute error between the original and the segmented calculations for a 4-minute clip. It's weird that the overlapping areas are not the only ones affected: the error seems to bleed into the entire rest of the segment.
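One plausible explanation, offered as a hypothesis rather than a diagnosis, involves two effects. First, MERT's encoder is a Transformer, so every output frame attends to every input frame in the window. Removing context can therefore change the output at every position, not just near the cut. Second, if `self.processor` is a `Wav2Vec2FeatureExtractor` (as in the MERT usage examples) with `do_normalize` enabled, each input is scaled to zero mean and unit variance over that input alone. Each segment is then rescaled differently from the full clip. Whether normalization is enabled depends on the checkpoint's preprocessor config. The toy pure-Python sketch below uses scalar frames and no learned weights. `self_attention` and `normalize` are illustrative stand-ins, not MERT code. It shows each effect changing every position of the first half of a signal once the second half is removed.

```python
import math


def normalize(xs):
    """Zero-mean, unit-variance scaling computed over the whole input."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / math.sqrt(var + 1e-7) for x in xs]


def self_attention(xs):
    """Toy single-head attention over scalar frames with no learned weights:
    output i is a softmax(x_i * x_j)-weighted average of all frames x_j."""
    out = []
    for q in xs:
        scores = [q * k for k in xs]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        out.append(sum(w * v for w, v in zip(weights, xs)) / sum(weights))
    return out


signal = [math.sin(0.3 * t) + 0.1 * t for t in range(40)]
clip = normalize(signal)

# Effect 1: identical input samples, but attention sees only the first 20 frames
attn_diff = [abs(a - b) for a, b in
             zip(self_attention(clip)[:20], self_attention(clip[:20]))]

# Effect 2: normalization statistics computed over the first half only
norm_diff = [abs(a - b) for a, b in zip(clip[:20], normalize(signal[:20]))]

print(f"attention:     min {min(attn_diff):.3g}, max {max(attn_diff):.3g}")
print(f"normalization: min {min(norm_diff):.3g}, max {max(norm_diff):.3g}")
```

If either effect is at play, two mitigations would be worth trying. One is giving each kept region generous context on both sides. The other is normalizing the full clip once before splitting, with per-segment normalization turned off.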
