microsoft / unilm

Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities
https://aka.ms/GeneralAI
MIT License

wavlm extract features failed in loop #596

Closed Ramlinbird closed 1 year ago

Ramlinbird commented 2 years ago

Describe the bug
Model I am using: WavLM. I followed the scripts in the WavLM README (I modified them to run on a GPU device; I also tried CPU mode and got the same memory exception).

import torch
from WavLM import WavLM, WavLMConfig
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load the pre-trained checkpoints
checkpoint = torch.load('/path/to/wavlm.pt')
cfg = WavLMConfig(checkpoint['cfg'])
model = WavLM(cfg).to(device)
model.load_state_dict(checkpoint['model'])
model.eval()

# extract the representation of last layer
wav_input_16khz = torch.randn(1,10000).to(device)
rep = model.extract_features(wav_input_16khz)[0]

This works well, but when I run it in a loop, it raises a memory error:

......
loop = []
for i in range(10):
    rep = model.extract_features(wav_input_16khz)[0]
    loop.append(rep)

[screenshot of the memory error]

mjhydri commented 2 years ago

I don't think this is a bug in the model. Most of your GPU memory is already in use, and by appending every representation to a single list you eventually fail to allocate more memory.

My suggested solutions:
1. Dump each output representation to your hard disk rather than appending them all to a list.
2. If you are using the large model, switch to the Base or Base+ model. Fewer parameters means less memory use.
3. Switching to the CPU will make inference slower but may fix the error.
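Suggestion 1 can be sketched like this. The `model.extract_features` call is the one from the thread; here a random tensor stands in for the real output (with a hypothetical WavLM-Base-like shape) so the sketch runs on its own:

```python
import os
import tempfile
import torch

out_dir = tempfile.mkdtemp()  # any writable directory works

for i in range(10):
    # rep = model.extract_features(wav_input_16khz)[0]  # real call from the thread
    rep = torch.randn(1, 31, 768)  # stand-in for the extracted representation
    # move to CPU and write to disk, so nothing accumulates across iterations
    torch.save(rep.cpu(), os.path.join(out_dir, f"rep_{i}.pt"))
    del rep

# later, load one representation at a time instead of holding all of them
first = torch.load(os.path.join(out_dir, "rep_0.pt"))
```

This keeps only one representation in memory at a time; the trade-off is the extra disk I/O per iteration.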

RobertBoganKang commented 2 years ago

I tried a lot and found the answer. Reference: https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification. We should wrap feature extraction in with torch.no_grad():.

For example:

with torch.no_grad():
    feature = <<extract_feature_model>>(audio)

In your case, fix it with:

......
loop = []
for i in range(10):
    with torch.no_grad():
        rep = model.extract_features(wav_input_16khz)[0]
    loop.append(rep)

Hope this works.
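The reason this fixes the loop: without torch.no_grad(), every rep keeps its autograd graph (and all intermediate activations) alive, so memory grows each iteration. A self-contained sketch, with a tiny stand-in module in place of WavLM so it runs without the checkpoint:

```python
import torch

class DummyExtractor(torch.nn.Module):
    """Stand-in for WavLM; mimics the tuple return of extract_features."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(10000, 768)

    def extract_features(self, wav):
        return (self.proj(wav),)

model = DummyExtractor().eval()
wav_input_16khz = torch.randn(1, 10000)

loop = []
for i in range(10):
    with torch.no_grad():  # no graph is built, so nothing extra is retained
        rep = model.extract_features(wav_input_16khz)[0]
    loop.append(rep.cpu())  # optionally move results off the GPU as well

# none of the stored tensors carry a graph, so memory stays flat
assert not any(t.requires_grad for t in loop)
```

Calling .cpu() on each result is an extra precaution for GPU runs; torch.no_grad() alone is what stops the per-iteration growth.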

Ramlinbird commented 1 year ago

RobertBoganKang's answer seems to work. Thanks.