rdevon / cortex

A machine learning library for PyTorch
BSD 3-Clause "New" or "Revised" License

Pass through dataloader class in DatasetPlugin #186

dmitriy-serdyuk closed this 6 years ago

dmitriy-serdyuk commented 6 years ago

I want to be able to assign my own data loader. This is useful, for example, when working with sequences, where a custom collate function is needed.

With this change, a dataset plugin can set its own dataloader class like this:

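class MyDataset(DatasetPlugin):
    def handle(self, ...):
        ...
        self.add_dataset(...)
        # Use a custom dataloader class (e.g. one with its own collate_fn)
        # instead of the default torch.utils.data.DataLoader.
        self.set_dataloader_class(my_fancy_dataloader)
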
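
To make the pass-through idea concrete, here is a minimal sketch of how a plugin could store a user-supplied loader class and use it when building loaders. `ToyDatasetPlugin` and `make_loaders` are hypothetical names, not cortex's actual code; the sketch only assumes the loader class is called like `torch.utils.data.DataLoader(dataset, batch_size=...)`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyDatasetPlugin:
    """Hypothetical stand-in for DatasetPlugin; not cortex's actual code."""

    def __init__(self):
        self.datasets = {}
        # Default to the stock PyTorch DataLoader unless a plugin overrides it.
        self.dataloader_class = DataLoader

    def add_dataset(self, mode, dataset):
        self.datasets[mode] = dataset

    def set_dataloader_class(self, dataloader_class):
        # Any callable with DataLoader's constructor signature works here,
        # e.g. a DataLoader subclass or functools.partial(DataLoader, ...).
        self.dataloader_class = dataloader_class

    def make_loaders(self, batch_size):
        # The framework passes only generic arguments; anything else
        # (collate_fn, samplers, ...) is baked into dataloader_class.
        return {mode: self.dataloader_class(ds, batch_size=batch_size)
                for mode, ds in self.datasets.items()}


plugin = ToyDatasetPlugin()
plugin.add_dataset('train', TensorDataset(torch.arange(10.)))
loaders = plugin.make_loaders(batch_size=4)
print(len(loaders['train']))  # 3 batches: 4 + 4 + 2
```

Because only the class is swapped, the framework-side call stays the same whether the class is `DataLoader` itself, a subclass, or a `functools.partial`.
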
rdevon commented 6 years ago

This is great! Does this work as expected?

dmitriy-serdyuk commented 6 years ago

I confirm that it works as expected.

dmitriy-serdyuk commented 6 years ago

An example of how I intend to use it:

I have a custom PyTorch dataset, WSJDataset, which outputs a spectrogram for each requested item. The problem is that the spectrogram dimensions differ from item to item. Usually I would use a collate function to aggregate a batch into a single tensor, but cortex currently doesn't allow passing arguments to the dataloader.
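
A typical collate function for this case pads each spectrogram to the longest one in the batch. The sketch below assumes each item is a `(time, frequency)` tensor whose time axis varies; `pad_collate` is a hypothetical name, and WSJDataset's real output format is not shown in the thread.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    """Collate variable-length spectrograms into one padded tensor.

    Assumes each item is a (time, freq) tensor whose time length varies.
    Returns the padded batch (batch, max_time, freq) plus the true lengths.
    """
    lengths = torch.tensor([item.shape[0] for item in batch])
    # pad_sequence pads every item along dim 0 up to the longest item.
    padded = pad_sequence(batch, batch_first=True, padding_value=0.0)
    return padded, lengths


batch = [torch.randn(5, 80), torch.randn(3, 80), torch.randn(7, 80)]
padded, lengths = pad_collate(batch)
print(padded.shape)      # torch.Size([3, 7, 80])
print(lengths.tolist())  # [5, 3, 7]
```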

This change allows that, as well as using any kind of custom dataloader (as long as it quacks like a dataloader).
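
"Quacks like a dataloader" means duck typing: the loader does not have to subclass `DataLoader`. As a hypothetical example, assuming the plugin only constructs the loader with a dataset plus keyword arguments, then calls `len()` on it and iterates over it (the thread does not state cortex's exact requirements), a class like this would qualify:

```python
import random


class SimpleLoader:
    """Hypothetical loader that 'quacks like' torch.utils.data.DataLoader.

    It takes a dataset plus keyword arguments, supports len(), and yields
    collated batches when iterated; it does not subclass DataLoader.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, collate_fn=list):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.collate_fn = collate_fn

    def __len__(self):
        # Number of batches, keeping the final partial batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(indices)
        for start in range(0, len(indices), self.batch_size):
            chunk = indices[start:start + self.batch_size]
            yield self.collate_fn([self.dataset[i] for i in chunk])


loader = SimpleLoader(['a', 'bb', 'ccc', 'dddd', 'eeeee'], batch_size=2)
print(len(loader))   # 3
print(list(loader))  # [['a', 'bb'], ['ccc', 'dddd'], ['eeeee']]
```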

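from functools import partial

from torch.utils.data import DataLoader
from cortex.plugins import DatasetPlugin
# WSJDataset is my own torch.utils.data.Dataset subclass (not shown here).


class WSJ(DatasetPlugin):
    sources = ['WSJ']

    def handle(self, source, copy_to_local=False, normalize=True,
               tanh_normalization=False, **transform_args):
        train_set = WSJDataset()
        test_set = WSJDataset()

        def collate(batch):
            # Merge variable-length spectrograms into a single batch tensor.
            collated_batch = ...
            return collated_batch

        # Pre-bind collate_fn; the remaining DataLoader arguments are
        # supplied when the loaders are built.
        WSJDataLoader = partial(DataLoader, collate_fn=collate)
        self.add_dataset('train', train_set)
        self.set_dataloader_class(dataloader_class=WSJDataLoader)
        self.add_dataset('test', test_set)
        self.set_input_names(['images'])
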
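
The `partial` trick works because the resulting object can be called exactly like the `DataLoader` class, with `collate_fn` already bound. A self-contained check, using a hypothetical `ToySpectrogramDataset` in place of WSJDataset and zero-padding as an assumed collate strategy:

```python
from functools import partial

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class ToySpectrogramDataset(Dataset):
    """Hypothetical stand-in for WSJDataset: items of varying time length."""

    def __init__(self, lengths, n_freq=4):
        self.items = [torch.ones(t, n_freq) for t in lengths]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def collate(batch):
    # Zero-pad along the time axis so the batch becomes a single tensor.
    return pad_sequence(batch, batch_first=True)


# Same pattern as in the plugin: pre-bind collate_fn, leave the rest open.
WSJDataLoader = partial(DataLoader, collate_fn=collate)
loader = WSJDataLoader(ToySpectrogramDataset([2, 5, 3]), batch_size=3)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([3, 5, 4])
```

Without the custom `collate_fn`, PyTorch's default collation would try to stack tensors of different lengths and fail.
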
soroushmehr commented 6 years ago

That's great, but I am not sure how it would interact with the get_dim function: https://github.com/rdevon/cortex/blob/master/cortex/plugins.py#L314

It would be awesome if this function could take into account the lengths of the sequences in each mini-batch.

dmitriy-serdyuk commented 6 years ago

@soroushmehr, how does get_dim work?

soroushmehr commented 6 years ago

@rdevon, correct me if I am wrong, but when you register or load a dataset via cortex, you define its dimensions with set_dim. This works fine for data with a fixed structure, but not for sequences, graphs, etc.

rdevon commented 6 years ago

get_dim and set_dim only deal with static properties of the dataset, which the user sets, so this shouldn't be a problem.

soroushmehr commented 6 years ago

Which means that the sequence length, or any dimension that changes from one iteration to another, has to be extracted from the actual mini-batch data.
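
One way to read the per-batch dimension from the data itself, assuming the collate function returns the padded tensor together with the true lengths (as a padding collate function can): take the time dimension from the padded tensor's shape and build a padding mask from the lengths. `length_mask` is a hypothetical helper.

```python
import torch


def length_mask(lengths, max_len=None):
    """Boolean mask (batch, max_len): True where a position holds real data."""
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len).unsqueeze(0)  # (1, max_len)
    return positions < lengths.unsqueeze(1)         # broadcasts to (batch, max_len)


# A padded mini-batch: the time dimension (here 4) is only known per batch.
padded = torch.zeros(3, 4, 2)
lengths = torch.tensor([4, 2, 1])
time_steps = padded.shape[1]  # dynamic dimension read from the batch
mask = length_mask(lengths, time_steps)
print(time_steps)           # 4
print(mask.int().tolist())  # [[1, 1, 1, 1], [1, 1, 0, 0], [1, 0, 0, 0]]
```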

dmitriy-serdyuk commented 6 years ago

We can assign None to variable-size channels.
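
As an illustration of the None convention, assuming dimensions are declared as a plain per-example tuple (the thread does not show set_dim's signature, so this is not cortex's API): a declared integer must match the batch exactly, while None accepts any size. `check_batch_dims` is a hypothetical helper, and 80 frequency bins is an arbitrary example.

```python
import torch


def check_batch_dims(batch, dims):
    """Check a batch against declared per-example dims.

    `dims` is a hypothetical tuple such as (None, 80): an int must match the
    batch exactly, while None marks a variable-size dimension (e.g. time).
    The leading batch dimension is skipped.
    """
    shape = tuple(batch.shape[1:])
    if len(shape) != len(dims):
        return False
    return all(d is None or d == s for d, s in zip(dims, shape))


declared = (None, 80)  # variable time length, fixed 80 frequency bins
print(check_batch_dims(torch.zeros(8, 120, 80), declared))  # True
print(check_batch_dims(torch.zeros(8, 95, 80), declared))   # True
print(check_batch_dims(torch.zeros(8, 95, 40), declared))   # False
```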