yizhilll / MERT

Official implementation of the paper "Acoustic Music Understanding Model with Large-Scale Self-supervised Training".
Apache License 2.0

High VRAM usage for long audio. #6

Open hykilpikonna opened 1 year ago

hykilpikonna commented 1 year ago

The model seems to process the entire audio at once, which leads to high VRAM usage for long files. I was trying to compute MERT embeddings for a 9:58 track on an A100 80GB GPU, and it tried to allocate 90 GB of VRAM.

[screenshot of the out-of-memory error]
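
For reference, here is a minimal sketch of the failing single-pass computation. The checkpoint name, resampling step, and file name are assumptions for illustration, not this repo's official pipeline:

import torch
import torchaudio
from transformers import AutoModel, Wav2Vec2FeatureExtractor

# Assumed checkpoint; substitute whichever MERT checkpoint you actually use.
model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True).to("cuda")
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)

audio, sr = torchaudio.load("long_track.wav")  # hypothetical ~10-minute file
audio = torchaudio.functional.resample(audio.mean(0), sr, processor.sampling_rate)

inputs = processor(audio.numpy(), sampling_rate=processor.sampling_rate,
                   return_tensors="pt").to("cuda")
with torch.no_grad():
    # Self-attention runs over the full sequence (~45k frames at 75 Hz for
    # 10 minutes) in one shot, which is what exhausts even an 80 GB card.
    out = model(**inputs, output_hidden_states=True)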

Is it possible to split the audio first, process each segment, and obtain the same results? I tried splitting the audio into 60 s windows using the code below. Even though I managed to get the segmented embedding into the same shape, it has a large mean squared error against the embedding computed by passing the entire audio in at once.

# Assumes `import torch` and that self.sr, self.processor, self.model,
# and self.device are already initialised.
window_length = int(self.sr * 60)      # 60-second windows
overlap_length = int(self.sr * 4.987)  # ~4.987 s overlap (5 s minus one frame at 75 Hz)
overlap_frames = int(4.987 * 75) - 1   # overlap in frames at the 75 Hz frame rate
embeddings = []

print("Audio shape:", audio.shape)
print("Window length:", window_length)
print("Overlap length:", overlap_length)
print("Overlap frames:", overlap_frames)

# Iterate over audio with overlap
for start in range(0, audio.shape[0], window_length - overlap_length):
    end = start + window_length
    segment = audio[start:end]
    print("Segment:", segment.shape)
    # if len(segment) < window_length:
    #     break

    # Process each segment
    inputs = self.processor(segment, sampling_rate=self.sr, return_tensors="pt").to(self.device)
    with torch.no_grad():
        out = self.model(**inputs, output_hidden_states=True)
        out = torch.stack(out.hidden_states).squeeze() # [13 layers, timeframes, 768]
        out = out[11] # [timeframes, 768]

        print("Frames before:", out.shape[0])

        # Drop the overlapping frames from the end of every segment except the last
        if end < audio.shape[0]:
            out = out[:-overlap_frames, :]

        print("Frames after:", out.shape[0])

    embeddings.append(out)

# Concatenate the per-segment embeddings along the time axis
out = torch.cat(embeddings, dim=0)

return out
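
A note on why the chunked result can only approximate the single-pass one: within each window the transformer's self-attention is global, so every frame's embedding depends on the entire window, and frames near a chunk edge see a truncated context. One common mitigation is to give every chunk symmetric context and keep only its centre frames. The sketch below uses assumed window/context sizes and illustrative names, not this repo's official approach, and the ±1-frame alignment of the 75 Hz frame grid against the convolutional front end would still need checking:

import torch

def embed_chunked(audio, model, processor, sr, device,
                  window_s=30.0, context_s=5.0, frame_rate=75, layer=11):
    window = int(sr * window_s)
    context = int(sr * context_s)
    assert window > 2 * context, "window must exceed twice the context"
    hop = window - 2 * context                # stride so centre regions tile
    ctx_frames = int(context_s * frame_rate)  # frames to drop at each edge
    pieces, start = [], 0
    while True:
        segment = audio[start:start + window]
        inputs = processor(segment, sampling_rate=sr,
                           return_tensors="pt").to(device)
        with torch.no_grad():
            hs = model(**inputs, output_hidden_states=True).hidden_states
            out = torch.stack(hs).squeeze(1)[layer]  # [frames, 768]
        last = start + window >= audio.shape[0]
        lo = 0 if start == 0 else ctx_frames          # keep left edge of first chunk
        hi = out.shape[0] if last else out.shape[0] - ctx_frames
        pieces.append(out[lo:hi])
        if last:
            break
        start += hop
    return torch.cat(pieces, dim=0)

Larger context_s should bring the stitched result closer to the single-pass one, at the cost of more compute per chunk.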
hykilpikonna commented 1 year ago

Here is the absolute error between the original and the segmented calculations for a 4-minute track, plotted below. Oddly, the overlapping regions are not the only thing affected; the error bleeds into the entire rest of each segment.

[graph: per-frame absolute error between the full-pass and segmented embeddings]
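
One likely explanation for the bleed: within a window, self-attention is global, so moving the window boundaries changes every frame's embedding, not just the frames inside the overlap. For anyone reproducing the plot, here is a sketch of how the per-frame error can be computed; embed_full and embed_segmented are hypothetical stand-ins for the one-shot and chunked code above:

import matplotlib.pyplot as plt

full = embed_full(audio)               # [timeframes, 768], whole file in one pass
seg = embed_segmented(audio)           # [timeframes, 768], stitched from chunks
err = (full - seg).abs().mean(dim=1)   # mean absolute error per 75 Hz frame
plt.plot(err.cpu().numpy())
plt.xlabel("frame (75 Hz)")
plt.ylabel("mean |error|")
plt.show()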