google / uis-rnn

This is the library for the Unbounded Interleaved-State Recurrent Neural Network (UIS-RNN) algorithm, corresponding to the paper Fully Supervised Speaker Diarization.
https://arxiv.org/abs/1810.04719
Apache License 2.0

Question about custom data generator #75

Open YanaHontarenko opened 4 years ago

YanaHontarenko commented 4 years ago

I saw your earlier answer that "for big amount of data you can fit model several times" (#8). But I haven't worked with PyTorch before and don't know how this is supposed to work: how is information about losses and gradients passed between different parts of the dataset? That's why I want to ask whether your library can fit with a custom data generator (like fit_generator in Keras). Or maybe you can point me to an example of such a case.

This is what my data class looks like (I previously saved the different parts of the data in "data.npz"):

import numpy as np
from torch.utils.data import DataLoader, Dataset

class Data(Dataset):
    def __init__(self, set, batch_size=32, shuffle=True):
        # "data.npz" holds '<set>_sequence' and '<set>_cluster_id' arrays
        self.data = np.load("data.npz", allow_pickle=True)
        self.set = set
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches; a trailing partial batch is dropped
        return int(np.floor(len(self.indexes) / self.batch_size))

    def __getitem__(self, index):
        # Each item is a whole batch, not a single sample
        temp_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        sequences_batch, clusters_batch = self.__data_generation(temp_indexes)

        return sequences_batch, clusters_batch

    def on_epoch_end(self):
        # Rebuild the index array and optionally reshuffle it
        self.indexes = np.arange(self.data[f'{self.set}_sequence'].shape[0])
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, temp_indexes):
        sequences = self.data[f'{self.set}_sequence'][temp_indexes]
        clusters = self.data[f'{self.set}_cluster_id'][temp_indexes]
        # Cast features to float (with a small offset) and cluster IDs to strings
        sequences = [seq.astype(float) + 0.00001 for seq in sequences]
        clusters = [np.array(cid).astype(str) for cid in clusters]

        return sequences, clusters

And this is how I create the generator:

train_set = Data("train", 32, True)
train_generator = DataLoader(train_set)
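
For reference, the "fit several times" idea quoted from #8 could be sketched roughly as follows. This is only a minimal illustration, not the library's own API: `iter_batches` and `fit_in_chunks` are hypothetical helpers, and `fit_fn` stands in for whatever call actually trains the model on one chunk. The dataset only needs `__len__` and `__getitem__` returning a `(sequences, cluster_ids)` batch, as the `Data` class above does.

```python
from typing import Callable, Iterator, Tuple

Batch = Tuple[list, list]  # (sequences, cluster_ids) for one chunk


def iter_batches(dataset, epochs: int = 1) -> Iterator[Batch]:
    """Yield every batch of an indexable dataset, once per epoch."""
    for _ in range(epochs):
        for i in range(len(dataset)):
            yield dataset[i]
        # Reshuffle between epochs if the dataset offers such a hook
        if hasattr(dataset, "on_epoch_end"):
            dataset.on_epoch_end()


def fit_in_chunks(fit_fn: Callable[[list, list], None],
                  dataset, epochs: int = 1) -> int:
    """Call fit_fn once per batch and return the number of calls made."""
    calls = 0
    for sequences, cluster_ids in iter_batches(dataset, epochs):
        fit_fn(sequences, cluster_ids)  # hypothetical per-chunk training step
        calls += 1
    return calls
```

With uis-rnn, `fit_fn` might be something like `lambda seqs, ids: model.fit(seqs, ids, training_args)`, assuming the `fit(train_sequences, train_cluster_ids, args)` signature shown in the project README and that repeated calls keep updating the same model, as the answer in #8 suggests. Indexing the dataset directly also avoids `DataLoader`'s default collation, which converts NumPy arrays to tensors and is likely to fail on string arrays such as the cluster IDs here.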

P.S.: I'd be happy to receive any help, because I'm not even sure that I'm going in the right direction.