lucidrains / musiclm-pytorch

Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in Pytorch
MIT License

I'm getting a memory error that seems unrealistic (small dataset) so I think I've messed up or there's a bug #30

Open besketh opened 1 year ago

besketh commented 1 year ago

Could you help with this when you have a moment, please? I'd be very appreciative.

This is the error:

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\pydevd.py", line 1496, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:\ML\playingWithMusicLM\main.py", line 114, in <module>
    loss = mulan(wavsTensor, selectedTextsTensor)
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<@beartype(musiclm_pytorch.musiclm_pytorch.MuLaN.forward) at 0x24d7011c670>", line 47, in forward
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 547, in forward
    audio_latents = self.get_audio_latents(wavs)
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 523, in get_audio_latents
    audio_embeds = self.audio(wavs)
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 396, in forward
    rel_pos_bias = self.dynamic_pos_bias_mlp(rel_dist.float())
RuntimeError: [enforce fail at ..\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 2250000000 bytes.

I'm just running 10 song/text pairs as tensor inputs in order to "train MuLaN".
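For context (editor's note, not from the original thread): the failing allocation is 2,250,000,000 bytes, i.e. about 562.5 million float32 values. The dynamic relative-position bias in the audio transformer is computed over every (query, key) pair of spectrogram patches, so its memory grows quadratically with sequence length, and a 480,000-sample waveform yields a lot of patches. A rough, hedged sketch of that scaling (the exact tensor shapes inside musiclm-pytorch are not verified here):

# Hedged back-of-envelope: memory of a pairwise relative-distance / bias
# tensor that stores one float32 per (query, key) pair. The real module may
# allocate additional per-head or MLP-hidden-dim tensors on top of this.
def pairwise_bias_bytes(seq_len: int, batch: int = 1, bytes_per_el: int = 4) -> int:
    return batch * seq_len * seq_len * bytes_per_el

for n in (1_000, 2_500, 5_000, 7_500, 10_000):
    gb = pairwise_bias_bytes(n, batch=10) / 1e9
    print(f"seq_len={n:>6}: ~{gb:.2f} GB for a batch of 10 clips")

# 10 * 7500^2 * 4 bytes = 2,250,000,000 bytes, exactly the failing allocation,
# which would correspond to roughly 7,500 patches per 480,000-sample clip
# (an inference from the numbers, not verified against the library internals).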

The code to do so is as follows:

# get a ton of <sound, text> pairs and train
import csv
import os

import numpy as np
import torch
from scipy.io.wavfile import read

# read <ytid, description> rows from the metadata CSV (skipping the header row)
ids = []
texts = []
with open(musicDescriptiveMetadataFilename, newline='\n') as csvfile:
    rows = csv.reader(csvfile, delimiter=',')
    for rowNumber, row in enumerate(rows):
        if rowNumber > 0:
            ids.append(row[0])
            texts.append(row[5])

wavs = []
selectedTexts = []

audioFileNames = os.listdir(".//ytRips")
for n, id in enumerate(ids):
    if n < 10:
        for audioFileName in audioFileNames:
            if id in audioFileName:
                a = read(".\\ytRips\\" + audioFileName)
                a = np.array(a[1], dtype=np.float32)
                try:
                    channels = a.shape[1]
                except IndexError:
                    # mono files have no channel dimension; skip them
                    channels = 1
                    continue

                samples = a.shape[0]

                if channels == 2:
                    # collapse stereo down to a single column of samples
                    a = np.resize(a, (samples, 1))

                if samples == 480000:
                    wavs.append(a)
                    # stringToListOfInts (defined elsewhere) maps each character to its ord() code
                    selectedTexts.append(np.asarray(stringToListOfInts(texts[n]), dtype=np.compat.long))

# resize all text token arrays to the same length (450)
resizedSelectedTexts = []
for selectedText in selectedTexts:
    size = selectedText.shape[0]
    if size > 450:
        resizedSelectedTexts.append(np.resize(selectedText, (450, 1)))
    else:
        # tile the short text until it is long enough, then cut it down to 450
        tmp = selectedText
        for x in range(10):
            tmp = np.concatenate((tmp, selectedText), axis=0)
        resizedSelectedTexts.append(np.resize(tmp, (450, 1)))

wavsTensor = torch.squeeze(torch.tensor(np.stack(wavs, axis=0), dtype=torch.float32))
selectedTextsTensor = torch.squeeze(torch.tensor(np.stack(resizedSelectedTexts, axis=0), dtype=torch.long))

loss = mulan(wavsTensor, selectedTextsTensor)
besketh commented 1 year ago

Pretty sure it's just that my code is trash. I tried to run before I learnt to walk. I've since learnt how to use PyTorch and a DataLoader class better, so I'll just close this.

jaideep2 commented 1 year ago

Hi @besketh, did you manage to solve it? My implementation here https://github.com/lucidrains/musiclm-pytorch/issues/16#issuecomment-1470533654 has similar issues. Wondering how to fix it.

besketh commented 1 year ago

Hi there @jaideep2, I actually redid the code to a standard that I'm happy with, so I'm not convinced it's my fault any more, haha. So I think I'll open this ticket up again. I'm still getting memory issues, so let me know if you find anything that solves this.

There are videos like this https://www.youtube.com/watch?v=uQx2bbRzvKI that get into advanced memory management for PyTorch, which was going to be my next step, but I haven't had time yet to dive deeper.
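Editor's note: as a hedged illustration of the kind of memory-management lever such material covers, gradient checkpointing trades extra compute for less stored activation memory. This is a generic sketch on a plain nn.Sequential; whether and how it could be wired into MuLaN is an open question here, and it would not remove the forward-pass allocation shown in the traceback above.

import torch
from torch.utils.checkpoint import checkpoint_sequential

# Generic illustration only (not specific to musiclm-pytorch): split a
# sequential model into segments and recompute activations during backward.
model = torch.nn.Sequential(*[torch.nn.Linear(512, 512) for _ in range(8)])
x = torch.randn(4, 512, requires_grad=True)

out = checkpoint_sequential(model, 4, x)  # 4 checkpointed segments
out.sum().backward()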

besketh commented 1 year ago

Current implementation:

import os

import numpy
import pandas
import soundfile
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

class MusicCapsDataset(Dataset):
    def __init__(self):
        print("instantiating dataset")
        self.metadata = pandas.read_csv("musiccaps-public-cleaned2.csv")[0:1]
        self.audioDir = "E:\\ML\\ytRips"
        # self.dropMetadataForInvalidAudioFiles()

    def __len__(self):
        return (len(self.metadata))

    def dropMetadataForInvalidAudioFiles(self):
        print("finding invalid audiofiles, dropping datapoints")
        droppedIndexes = []
        for n in range((self.metadata.shape[0])):
            shape = self.getAudioFileShape(n)
            if not self.isValidAudio(shape):
                droppedIndexes.append(n)

        for droppedIndex in droppedIndexes:
            print("dropping " + str(droppedIndex))
            self.metadata.drop(axis=0, index=droppedIndex, inplace=True)

        self.metadata.to_csv("musiccaps-public-cleaned2.csv")

        print(f"Dropped {str(len(droppedIndexes))} items")

    def getAudioFileShape(self, index):
        id = self.getIdForIndex(index)
        path = self.getAudioFilePath(id)
        if not path:
            return None
        return soundfile.read(path)[0].shape

    def isValidAudio(self, shape):
        exists = False
        fortyEightK = False

        if shape is not None:
            exists = True
            if shape[0] == 480000:
                fortyEightK = True

        return (exists and fortyEightK)

    def getIdForIndex(self, index):
        return self.metadata.iloc[index, 1]

    def getIndexForId(self, id):
        return self.getMetaData(id).index.item()

    def __getitem__(self, index):
        id = self.getIdForIndex(index)
        audio_path = self.getAudioFilePath(id)
        description = self.forceResizeTexts(self.getDescription(id), 450)

        return audio_path, description

    def getTensorForAudioPath(self, audio_path):
        signal, sr = torchaudio.load(audio_path)
        if signal.shape[0] == 2:
            return self.enforceShape(signal[0])
        else:
            return torch.squeeze(signal)

    def stringToListOfInts(self, s):
        a = []
        for char in list(s):
            a.append(ord(char))
        return a

    def getTensorForString(self, string):
        l = self.stringToListOfInts(string)
        a = numpy.asarray(l, dtype=numpy.compat.long)
        return torch.tensor(a, dtype=torch.long)

    def getAudioFilePath(self, id):
        audioFileNames = os.listdir(self.audioDir)
        for audioFileName in audioFileNames:
            if str(id) in audioFileName:
                return os.path.join(self.audioDir, audioFileName)
        return None

    def getMetaData(self, id):
        query = 'ytid == "' + str(id) + '"'
        return self.metadata.query(query)

    def getDescription(self, id):
        index = self.getIndexForId(id)
        return self.metadata.iloc[index, 6]

    def forceResizeTexts(self, text, intendedSize):
        size = len(text)
        if size >= intendedSize:
            text = text[:intendedSize]
            assert len(text) == intendedSize
            return text
        else:
            diff = intendedSize - size
            repeat = int(float(size) / float(diff) + float(3))
            for i in range(repeat):
                text = text + text
            text = text[:intendedSize]
            assert len(text) == intendedSize
            return text

    def enforceShape(self, audioTensor):
        if audioTensor.shape[0] != 480000:
            return audioTensor[1]
        else:
            return audioTensor

if __name__ == "__main__":

    # importing the module
    import tracemalloc

    # starting the monitoring
    tracemalloc.start()

    # displaying the memory
    print("setting up mulan with audio and text transformer "+ str(tracemalloc.get_traced_memory()))

    audio_transformer = AudioSpectrogramTransformer(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64,
        spec_n_fft=128,
        spec_win_length=24,
        spec_aug_stretch_factor=0.8
    )

    text_transformer = TextTransformer(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64
    )

    mulan = MuLaN(
        audio_transformer=audio_transformer,
        text_transformer=text_transformer
    )

    print("instantiating data: " + str(tracemalloc.get_traced_memory()))

    musicCapsDataset = MusicCapsDataset()
    EPOCHS = 10
    BATCH_SIZE = 1
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    print("creating data loader: " + str(tracemalloc.get_traced_memory()))
    trainDataLoader = DataLoader(musicCapsDataset, batch_size=BATCH_SIZE, num_workers=0)

    def trainOneEpoch():
        print("starting epoch: " + str(tracemalloc.get_traced_memory()))
        count = 0
        for audio_paths, descriptions in trainDataLoader:

            audioTensors = []
            descriptionTensors = []
            print("creating audio tensors: " + str(tracemalloc.get_traced_memory()))

            for audio_path in audio_paths:
                audioTensors.append(musicCapsDataset.getTensorForAudioPath(audio_path))

            print("creating description tensors: " + str(tracemalloc.get_traced_memory()))
            for description in descriptions:
                descriptionTensors.append(
                    musicCapsDataset.getTensorForString(musicCapsDataset.forceResizeTexts(description, 450)))

            print("stacking: " + str(tracemalloc.get_traced_memory()))
            audioTensor = torch.stack(audioTensors)
            descriptionTensor = torch.stack(descriptionTensors)

            print("mulan loss function: " + str(tracemalloc.get_traced_memory()))
            loss = mulan(audioTensor, descriptionTensor)

            print("backwards propogation: " + str(tracemalloc.get_traced_memory()))
            loss.backward()

    def train():
        for epoch in range(EPOCHS):
            print("epoch " + str(epoch))
            trainOneEpoch()

    train()
    # stopping the library
    tracemalloc.stop()
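Editor's note: the loop above calls loss.backward() without resetting gradients or stepping an optimizer, so no weights are ever updated. A hedged sketch of how the MuLaN step might be wired up with a standard optimizer (Adam and the learning rate are illustrative assumptions, not values the repository prescribes):

# Hedged sketch: wiring an optimizer into the loop above (optimizer choice
# and lr are assumptions for illustration, not recommended by musiclm-pytorch).
optimizer = torch.optim.Adam(mulan.parameters(), lr=3e-4)

for audio_paths, descriptions in trainDataLoader:
    audioTensor = torch.stack([musicCapsDataset.getTensorForAudioPath(p) for p in audio_paths])
    descriptionTensor = torch.stack([
        musicCapsDataset.getTensorForString(musicCapsDataset.forceResizeTexts(d, 450))
        for d in descriptions
    ])

    optimizer.zero_grad()   # clear gradients from the previous step
    loss = mulan(audioTensor, descriptionTensor)
    loss.backward()         # compute gradients
    optimizer.step()        # update MuLaN's weights
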
Mingxiangyu commented 1 year ago
F:\PyCharm 2020.3.3\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_utils.py:605: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.
  for item in s.iteritems():
Traceback (most recent call last):
  File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 628, in __next__
    data = self._next_data()
  File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "E:/WorkSpace/pyWorkSpace/musiclm-pytorch/issue1.py", line 73, in __getitem__
    description = self.forceResizeTexts(self.getDescription(id), 450)
  File "E:/WorkSpace/pyWorkSpace/musiclm-pytorch/issue1.py", line 107, in getDescription
    index = self.getIndexForId(id)
  File "E:/WorkSpace/pyWorkSpace/musiclm-pytorch/issue1.py", line 68, in getIndexForId
    return self.getMetaData(id).index.item()
  File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\pandas\core\base.py", line 347, in item
    raise ValueError("can only convert an array of size 1 to a Python scalar")
ValueError: can only convert an array of size 1 to a Python scalar
python-BaseException

Process finished with exit code -1

Hello @besketh, an error was reported during the training phase while executing this code.
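Editor's note: this ValueError comes from getIndexForId, which calls .index.item(); that only succeeds when the ytid query matches exactly one metadata row, so a missing or duplicated id blows up. A hedged sketch of a more defensive lookup, reusing the column name from the code above:

# Hedged sketch of a more defensive lookup: .item() requires exactly one
# match, so take the first matching row and fail loudly when the id is absent.
def getIndexForId(self, id):
    matches = self.metadata.query(f'ytid == "{id}"')
    if matches.empty:
        raise KeyError(f"no metadata row for ytid {id}")
    return matches.index[0]  # first matching row, even if the id repeats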

Newbas commented 1 year ago

Hey @besketh, thanks for sharing your code. Have you resolved the memory consumption, or is it down to the size of the model? I tried cutting the length of the input and it fits; I was also checking memory consumption on 120,000 frames (2.5 s of audio at 48 kHz) and it consumes around 15 GB of memory with only one audio file. So I'm curious whether it's because of the size of the model. I also saw that increasing the batch size increases consumption, but not 1:1; it grows more slowly (at least in my experiments). So I'm wondering what setup is needed to train on MusicCaps if everything here is correct. Also, this is just the MuLaN step.
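Editor's note: one hedged way to answer the scaling question is to probe peak memory at a few input lengths. The sketch below assumes a CUDA device and a mulan built as in the snippets above; the probe lengths and random token ids are placeholders, and a CPU-only run would need a different measurement (e.g. process RSS).

import torch

# Probe how peak GPU memory grows with input length (illustrative lengths).
mulan = mulan.cuda()
for n_samples in (60_000, 120_000, 240_000, 480_000):
    torch.cuda.reset_peak_memory_stats()
    wav = torch.randn(1, n_samples, device='cuda')          # fake waveform
    text = torch.randint(0, 256, (1, 450), device='cuda')   # placeholder token ids
    loss = mulan(wav, text)
    loss.backward()
    mulan.zero_grad(set_to_none=True)
    print(f"{n_samples} samples -> peak ~{torch.cuda.max_memory_allocated() / 1e9:.1f} GB")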