dmitriy-serdyuk closed this 6 years ago
This is great! Does this work as expected?
I confirm it works as expected.
I have a custom PyTorch dataset, WSJDataset, which outputs a spectrogram for the requested item. The problem is that the spectrogram dimension is different for each item. Usually I would use a collate function to aggregate a batch (create a single tensor), but cortex at the moment doesn't allow passing arguments to the dataloader.
This change will allow this as well as using any kind of custom dataloader (if it quacks like a dataloader).
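from functools import partial

from torch.utils.data import DataLoader

from cortex.plugins import DatasetPlugin

# WSJDataset is the custom PyTorch dataset described above, defined elsewhere.


class WSJ(DatasetPlugin):
    sources = ['WSJ']

    def handle(self, source, copy_to_local=False, normalize=True,
               tanh_normalization=False, **transform_args):
        train_set = WSJDataset()
        test_set = WSJDataset()

        def collate(batch):
            # Aggregate variable-size items into a single batch.
            collated_batch = ...
            return collated_batch

        # Bind the collate function so cortex can build the loader itself.
        WSJDataLoader = partial(DataLoader, collate_fn=collate)
        self.add_dataset('train', train_set)
        self.set_dataloader_class(dataloader_class=WSJDataLoader)
        self.add_dataset('test', test_set)
        self.set_input_names(['images'])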
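The body of the collate function is elided in the plugin above. As a rough illustration only, here is one way such a function could zero-pad variable-length spectrograms into a single tensor. ToySpectrograms, pad_collate and PaddingDataLoader are hypothetical names invented for this sketch; they are not part of cortex or of the WSJ code, and the real WSJDataset may need a different aggregation.

```python
from functools import partial

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class ToySpectrograms(Dataset):
    """Hypothetical stand-in for WSJDataset: items differ in length."""

    def __init__(self, lengths=(5, 3, 7, 4), n_freq=8):
        # Each item is a (time, freq) tensor with its own number of frames.
        self.items = [torch.randn(t, n_freq) for t in lengths]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


def pad_collate(batch):
    """Zero-pad every item to the longest one in the batch."""
    lengths = torch.tensor([item.shape[0] for item in batch])
    padded = pad_sequence(batch, batch_first=True)  # (batch, max_time, freq)
    return padded, lengths


# Same pattern as in the plugin: a DataLoader with the collate function bound.
PaddingDataLoader = partial(DataLoader, collate_fn=pad_collate)

loader = PaddingDataLoader(ToySpectrograms(), batch_size=2, shuffle=False)
batches = [(tuple(padded.shape), lengths.tolist()) for padded, lengths in loader]
```

Returning the original lengths next to the padded tensor lets the model mask out the padding later. A class built this way could then be handed to set_dataloader_class in the same way as WSJDataLoader.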
That's great, but I am not sure how it would interact with the get_dim function: https://github.com/rdevon/cortex/blob/master/cortex/plugins.py#L314
It would be awesome if this function could take into account the length of the sequences in the current mini-batch.
@soroushmehr how does get_dim work?
@rdevon correct me if I am wrong, but when you register or load a dataset via cortex, you define its dimensions with set_dim. This works fine for any data with a fixed structure, but that is not the case for sequences, graphs, etc.
So get_dim and set_dim both deal with static properties of the dataset, set by the user, so this shouldn't be a problem.
This means that the sequence length, or any dimension that changes from one iteration to another, should be extracted from the actual mini-batch data.
We can assign None for variable-size channels.
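To make the None convention concrete, here is a small sketch of how declared dimensions with None placeholders could be filled in from the shape of an actual mini-batch. resolve_dims is a hypothetical helper written for this comment; it is not cortex's get_dim or set_dim, and it assumes the first axis of the batch is the batch axis.

```python
def resolve_dims(static_dims, batch_shape):
    """Fill the None entries of static_dims from a mini-batch shape.

    static_dims: per-item dimensions declared up front, None where variable.
    batch_shape: shape of the collated batch, batch axis first (assumption).
    """
    item_shape = tuple(batch_shape[1:])  # drop the batch axis
    if len(item_shape) != len(static_dims):
        raise ValueError("batch rank does not match the declared dimensions")

    resolved = []
    for declared, actual in zip(static_dims, item_shape):
        # A fixed dimension must agree with the data; None accepts any size.
        if declared is not None and declared != actual:
            raise ValueError(f"expected size {declared}, got {actual}")
        resolved.append(actual)
    return tuple(resolved)


declared = (None, 8)  # variable number of frames, 8 frequency bins
first = resolve_dims(declared, (2, 5, 8))
second = resolve_dims(declared, (2, 7, 8))
```

The static part of the shape stays checked against what the user declared, while the variable part is read from each batch as it arrives.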
I want to be able to assign my own data loader. For example, this is useful when working with sequences, where a custom collate function is needed.
Now it is possible to create a dataset plugin that sets its own dataloader class, like the WSJ plugin shown above.