google / uis-rnn

This is the library for the Unbounded Interleaved-State Recurrent Neural Network (UIS-RNN) algorithm, corresponding to the paper Fully Supervised Speaker Diarization.
https://arxiv.org/abs/1810.04719
Apache License 2.0

Large datasets cause training machine to run out of memory #8

Closed. xinli94 closed this issue 5 years ago.

xinli94 commented 5 years ago

Hi,

I am working on training a uis-rnn model with dataset voxceleb2: http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html.

But loading the .npz files as training data causes out-of-memory errors that take down my machine. There are over 1,000,000 training clips in the dataset. Is it possible to make large datasets work with this API?

Thanks, Xin

wq2012 commented 5 years ago

Currently we do not have good support for training the model on large datasets.

For now, please just call the fit() function multiple times on different shards of your data. Do not try to train on all the data at once.

We will be working on some improvements to the fit() function later to better support this workflow.

Also, do not save all your data into one single .npz file. Save different shards into different files, and load one at a time.
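As a minimal sketch of this shard-by-shard pattern (assuming the `uisrnn.parse_arguments()` / `UISRNN.fit()` API described in the README; the shard file pattern and .npz keys below are hypothetical):

```python
import glob

import numpy as np
import uisrnn

# Standard uis-rnn setup: parse the default model/training/inference args.
model_args, training_args, inference_args = uisrnn.parse_arguments()
model = uisrnn.UISRNN(model_args)

# Train on one shard at a time instead of loading everything at once.
# The file pattern and array keys are hypothetical.
for shard_path in sorted(glob.glob('shards/train_shard_*.npz')):
    shard = np.load(shard_path, allow_pickle=True)
    model.fit(shard['train_sequences'], shard['train_cluster_ids'],
              training_args)

model.save('saved_uisrnn_model.uisrnn')
```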

xinli94 commented 5 years ago

Thanks!

wq2012 commented 5 years ago

@xinli94

We just committed some fixes to make training work better when fit() is called multiple times.

We also added some suggestions to README.md: https://github.com/google/uis-rnn#training-on-large-datasets
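The README recipe amounts to writing shards at preprocessing time; here is a hypothetical sketch (the helper name, file layout, and keys are illustrative, not part of the library):

```python
import numpy as np

# Hypothetical preprocessing helper: write every `shard_size` utterances
# into a separate .npz file so no single file must fit in memory later.
def save_shards(sequences, cluster_ids, shard_size=10000):
    for start in range(0, len(sequences), shard_size):
        np.savez(
            'shards/train_shard_{:05d}.npz'.format(start // shard_size),
            train_sequences=np.array(
                sequences[start:start + shard_size], dtype=object),
            train_cluster_ids=np.array(
                cluster_ids[start:start + shard_size], dtype=object),
        )
```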

fanlu commented 5 years ago

@wq2012 Why didn't you use torch.nn.DataParallel to support large datasets and multi-GPU training?
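For context, this is the generic PyTorch pattern being referred to: torch.nn.DataParallel replicates a module across visible GPUs and splits each input batch among them. The module below is a dummy, shown only to illustrate the question; it is not part of uis-rnn:

```python
import torch

# Dummy module standing in for a model; DataParallel itself is generic.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 64),
)
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)  # split batches across GPUs
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

batch = torch.randn(32, 256, device=device)  # dummy input batch
output = model(batch)  # forward pass is parallelized when multiple GPUs exist
```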