[Open] besketh opened this issue 1 year ago
Pretty sure it's just that my code is trash; I tried to run before I learned to walk. I've since learned how to use PyTorch and the DataLoader class better, so I'll just close this.
Hi @besketh, did you manage to solve it? My implementation here https://github.com/lucidrains/musiclm-pytorch/issues/16#issuecomment-1470533654 has similar issues; I'm wondering how to fix it.
Hi there @jaideep2, I actually redid the code to a standard that I'm happy with, so I'm not convinced it's my fault any more, haha. I think I'll reopen this ticket. I'm still getting memory issues, so let me know if you find anything that solves this.
There are videos like this https://www.youtube.com/watch?v=uQx2bbRzvKI that get into advanced memory management for PyTorch, which was going to be my next step, but I haven't had time to dive deeper yet.
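For reference, a minimal sketch of the kind of levers that material covers, applied to a generic training step. The `model`, `optimizer`, `audio`, and `text` names are placeholders (not from the code below), and it assumes a CUDA GPU with bfloat16 support:

import torch

# Hypothetical training step showing common memory levers; the names here are
# placeholders, not taken from the code in this thread.
def training_step(model, optimizer, audio, text, device="cuda"):
    audio, text = audio.to(device), text.to(device)
    optimizer.zero_grad(set_to_none=True)  # free old grad tensors instead of zero-filling them
    # bfloat16 autocast roughly halves activation memory and needs no GradScaler
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(audio, text)
    loss.backward()
    optimizer.step()
    loss_value = loss.item()   # keep a Python float, not the graph-holding tensor
    del loss, audio, text      # drop references so the caching allocator can reuse the blocks
    print(torch.cuda.max_memory_allocated() / 2**20, "MiB peak so far")
    return loss_value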
Current implementation:
import os

import numpy
import pandas
import soundfile
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer
class MusicCapsDataset(Dataset):
    def __init__(self):
        print("instantiating dataset")
        # NOTE: only the first row is loaded here ([0:1]) while debugging memory usage
        self.metadata = pandas.read_csv("musiccaps-public-cleaned2.csv")[0:1]
        self.audioDir = "E:\\ML\\ytRips"
        # self.dropMetadataForInvalidAudioFiles()

    def __len__(self):
        return len(self.metadata)

    def dropMetadataForInvalidAudioFiles(self):
        # Remove rows whose audio file is missing or not exactly 480000 samples (10 s at 48 kHz)
        print("finding invalid audio files, dropping datapoints")
        droppedIndexes = []
        for n in range(self.metadata.shape[0]):
            shape = self.getAudioFileShape(n)
            if not self.isValidAudio(shape):
                droppedIndexes.append(n)
        for droppedIndex in droppedIndexes:
            print("dropping " + str(droppedIndex))
            self.metadata.drop(axis=0, index=droppedIndex, inplace=True)
        self.metadata.to_csv("musiccaps-public-cleaned2.csv")
        print(f"Dropped {len(droppedIndexes)} items")

    def getAudioFileShape(self, index):
        id = self.getIdForIndex(index)
        path = self.getAudioFilePath(id)
        if not path:
            return None
        return soundfile.read(path)[0].shape

    def isValidAudio(self, shape):
        # Valid clips exist on disk and are exactly 480000 samples long
        return shape is not None and shape[0] == 480000

    def getIdForIndex(self, index):
        # Column 1 is the ytid column in this CSV layout
        return self.metadata.iloc[index, 1]

    def getIndexForId(self, id):
        # NOTE: .item() raises ValueError if the ytid query matches zero or more than one row
        return self.getMetaData(id).index.item()

    def __getitem__(self, index):
        id = self.getIdForIndex(index)
        audio_path = self.getAudioFilePath(id)
        description = self.forceResizeTexts(self.getDescription(id), 450)
        return audio_path, description

    def getTensorForAudioPath(self, audio_path):
        signal, sr = torchaudio.load(audio_path)
        if signal.shape[0] == 2:
            # Stereo: keep the first channel only
            return self.enforceShape(signal[0])
        else:
            return torch.squeeze(signal)

    def stringToListOfInts(self, s):
        # Encode each character as its ordinal value
        return [ord(char) for char in s]

    def getTensorForString(self, string):
        a = numpy.asarray(self.stringToListOfInts(string), dtype=numpy.int64)
        return torch.tensor(a, dtype=torch.long)

    def getAudioFilePath(self, id):
        # Return the first file in the audio directory whose name contains the id
        for audioFileName in os.listdir(self.audioDir):
            if str(id) in audioFileName:
                return os.path.join(self.audioDir, audioFileName)
        return None

    def getMetaData(self, id):
        return self.metadata.query('ytid == "' + str(id) + '"')

    def getDescription(self, id):
        # Column 6 is the caption column in this CSV layout
        index = self.getIndexForId(id)
        return self.metadata.iloc[index, 6]

    def forceResizeTexts(self, text, intendedSize):
        # Truncate long captions; tile short captions until they reach intendedSize
        if len(text) >= intendedSize:
            text = text[:intendedSize]
        else:
            while len(text) < intendedSize:
                text = text + text
            text = text[:intendedSize]
        assert len(text) == intendedSize
        return text

    def enforceShape(self, audioTensor):
        # NOTE: returning audioTensor[1] yields a scalar, not a clip; this branch likely
        # needs truncation or padding to 480000 samples instead
        if audioTensor.shape[0] != 480000:
            return audioTensor[1]
        else:
            return audioTensor
if __name__ == "__main__":
    # Trace Python-side allocations to see where memory goes
    import tracemalloc
    tracemalloc.start()

    print("setting up mulan with audio and text transformer " + str(tracemalloc.get_traced_memory()))
    audio_transformer = AudioSpectrogramTransformer(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64,
        spec_n_fft=128,
        spec_win_length=24,
        spec_aug_stretch_factor=0.8
    )
    text_transformer = TextTransformer(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64
    )
    mulan = MuLaN(
        audio_transformer=audio_transformer,
        text_transformer=text_transformer
    )

    print("instantiating data: " + str(tracemalloc.get_traced_memory()))
    musicCapsDataset = MusicCapsDataset()
    EPOCHS = 10
    BATCH_SIZE = 1
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("creating data loader: " + str(tracemalloc.get_traced_memory()))
    trainDataLoader = DataLoader(musicCapsDataset, batch_size=BATCH_SIZE, num_workers=0)

    def trainOneEpoch():
        print("starting epoch: " + str(tracemalloc.get_traced_memory()))
        for audio_paths, descriptions in trainDataLoader:
            audioTensors = []
            descriptionTensors = []
            print("creating audio tensors: " + str(tracemalloc.get_traced_memory()))
            for audio_path in audio_paths:
                audioTensors.append(musicCapsDataset.getTensorForAudioPath(audio_path))
            print("creating description tensors: " + str(tracemalloc.get_traced_memory()))
            for description in descriptions:
                descriptionTensors.append(
                    musicCapsDataset.getTensorForString(musicCapsDataset.forceResizeTexts(description, 450)))
            print("stacking: " + str(tracemalloc.get_traced_memory()))
            audioTensor = torch.stack(audioTensors)
            descriptionTensor = torch.stack(descriptionTensors)
            print("mulan loss function: " + str(tracemalloc.get_traced_memory()))
            loss = mulan(audioTensor, descriptionTensor)
            print("backwards propagation: " + str(tracemalloc.get_traced_memory()))
            # NOTE: there is no optimizer step or zero_grad here, so gradients accumulate
            # across iterations, and the model is never moved to `device` (see the sketch below)
            loss.backward()

    def train():
        for epoch in range(EPOCHS):
            print("epoch " + str(epoch))
            trainOneEpoch()

    train()
    # stop tracing allocations
    tracemalloc.stop()
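For comparison, here is a hedged sketch of what the loop above leaves out: moving the model to the GPU, an optimizer step, and clearing gradients between iterations. The Adam optimizer and learning rate are my assumptions, not something taken from musiclm-pytorch, and the sketch reuses `mulan`, `device`, `trainDataLoader`, and `musicCapsDataset` as defined in the script above.

# Hedged sketch: optimizer choice and learning rate are assumptions, not from the repo.
optimizer = torch.optim.Adam(mulan.parameters(), lr=3e-4)
mulan.to(device)  # the original loop never moves the model off the CPU

def trainOneEpochWithOptimizer():
    for audio_paths, descriptions in trainDataLoader:
        audioTensor = torch.stack(
            [musicCapsDataset.getTensorForAudioPath(p) for p in audio_paths]).to(device)
        descriptionTensor = torch.stack(
            [musicCapsDataset.getTensorForString(
                musicCapsDataset.forceResizeTexts(d, 450)) for d in descriptions]).to(device)
        optimizer.zero_grad(set_to_none=True)  # release old gradient buffers between steps
        loss = mulan(audioTensor, descriptionTensor)
        loss.backward()
        optimizer.step()
        print("loss", loss.item())  # .item() returns a float, so the graph can be freed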
F:\PyCharm 2020.3.3\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_utils.py:605: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.
for item in s.iteritems():
Traceback (most recent call last):
File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 628, in __next__
data = self._next_data()
File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 671, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 58, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 58, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "E:/WorkSpace/pyWorkSpace/musiclm-pytorch/issue1.py", line 73, in __getitem__
description = self.forceResizeTexts(self.getDescription(id), 450)
File "E:/WorkSpace/pyWorkSpace/musiclm-pytorch/issue1.py", line 107, in getDescription
index = self.getIndexForId(id)
File "E:/WorkSpace/pyWorkSpace/musiclm-pytorch/issue1.py", line 68, in getIndexForId
return self.getMetaData(id).index.item()
File "D:\Anaconda\envs\musiclm-pytorch\lib\site-packages\pandas\core\base.py", line 347, in item
raise ValueError("can only convert an array of size 1 to a Python scalar")
ValueError: can only convert an array of size 1 to a Python scalar
python-BaseException
Process finished with exit code -1
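For what it's worth, that ValueError comes from `.item()` in `getIndexForId`: `pandas.Index.item()` only works when the `ytid` query matched exactly one row. A minimal defensive version, assuming the same metadata layout as the class above, might look like this:

def getIndexForId(self, id):
    # Guard against the ytid query matching zero or multiple rows instead of letting .item() raise
    matches = self.metadata.query('ytid == "' + str(id) + '"')
    if len(matches) == 0:
        raise KeyError("no metadata row for ytid " + str(id))
    if len(matches) > 1:
        print("warning: " + str(len(matches)) + " rows for ytid " + str(id) + ", using the first")
    return matches.index[0]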
Hello @besketh, an error was reported during the training phase while executing this code.
Hey @besketh, thanks for sharing your code. Have you resolved the memory consumption, or is it just down to the size of the model? I tried cutting the length of the input and it fits; I also checked memory consumption at 120000 frames (2.5 s of audio at 48 kHz) and it consumes around 15 GB with only a single audio file. So I'm curious whether it's the size of the model. I also saw that increasing the batch size increases consumption, but not 1:1; it grows more slowly (at least in my experiments). So I'm wondering what setup is needed to train on MusicCaps if everything here is correct. Also, this is just the MuLaN step.
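If it helps narrow that down, peak CUDA memory per input length can be measured directly. A small sketch, assuming `mulan` from the script above has already been moved to a CUDA device; the frame counts and the text length of 450 are just examples:

import torch

# Sweep input length and report peak CUDA memory for one forward/backward pass.
for num_frames in (120_000, 240_000, 480_000):
    mulan.zero_grad(set_to_none=True)       # drop gradients from the previous sweep
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    wav = torch.randn(1, num_frames, device="cuda")
    text = torch.randint(0, 256, (1, 450), device="cuda")
    loss = mulan(wav, text)
    loss.backward()
    print(num_frames, "frames ->", round(torch.cuda.max_memory_allocated() / 2**30, 2), "GiB peak")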
Can you help with this when you have a moment, please? I'd much appreciate it.
This is the error:
I'm just running 10 song/text pairs as tensor parameters in order to "train MuLaN".
The code to do so is as follows: