AlexMili / torch-dataframe

Utility class to manipulate datasets from CSV files
MIT License

Add compatibility to make integration into Torchnet simpler #24

Closed. gforge closed this issue 8 years ago.

gforge commented 8 years ago

Facebook's Torchnet (just released) has its own dataset solution. It lacks a CSV interface, handling of categories, and core statistics. From what I understand, it is mostly an approach to sampling that requires datasets to implement two functions: size() and get().

This would require changing the Dataframe:size() function in the develop branch to return only the number of rows instead of both rows and columns. If I understand correctly, get() is synonymous with get_row(); from the torchnet documentation:

"In torchnet, a sample returned by dataset:get() is supposed to be a Lua table. Fields of the table can be arbitrary, even though many datasets will only work with torch tensors."

The latter sentence suggests that changing the internal storage (issue #16) may be wise for optimal integration.
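
For reference, the contract could be satisfied with a thin wrapper along these lines. This is only a minimal sketch, not the actual implementation: torchnetView and the column names input/target are mine, and it assumes size() has already been changed to return rows only.

-- Minimal sketch: expose a Dataframe `df` through the size()/get()
-- interface that torchnet expects from a dataset.
local function torchnetView(df)
   return {
      -- torchnet wants size() to return a single number: the row count
      size = function(self)
         return df:size()  -- assumes the proposed rows-only behaviour
      end,
      -- torchnet wants get(idx) to return a Lua table of fields
      get = function(self, idx)
         local row = df:get_row(idx)
         return {
            input  = row.input,   -- hypothetical column names
            target = row.target,
         }
      end,
   }
end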

gforge commented 8 years ago

The update is closing in on a much more mature state. I'm struggling with how to approach the parallel dataset iterator. In the mnist example it is rather straightforward:

local tnt = require 'torchnet'

-- function that sets up the dataset iterator:
local function getIterator(mode)
   return tnt.ParallelDatasetIterator{
      nthread = 1,
      init    = function() require 'torchnet' end,
      closure = function()

         -- load MNIST dataset:
         local mnist = require 'mnist'
         local dataset = mnist[mode .. 'dataset']()
         dataset.data = dataset.data:reshape(dataset.data:size(1),
            dataset.data:size(2) * dataset.data:size(3)):double()

         -- return batches of data:
         return tnt.BatchDataset{
            batchsize = 128,
            dataset = tnt.ListDataset{  -- replace this by your own dataset
               list = torch.range(1, dataset.data:size(1)):long(),
               load = function(idx)
                  return {
                     input  = dataset.data[idx],
                     target = torch.LongTensor{dataset.label[idx] + 1},
                  }  -- sample contains input and target
               end,
            }
         }
      end,
   }
end
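
For context, consuming such an iterator is then a plain loop (assuming, as in the torchnet examples, that the iterator instance is callable):

-- Iterate one epoch of batches; each sample is a table with `input`
-- and `target`, batched along the first dimension.
local iterator = getIterator('train')
for sample in iterator() do
   print(sample.input:size())   -- up to 128 x 784 in the MNIST example above
end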

The problem here is that they use the mnist package to load the actual data into each thread and then pass a global index to the batch dataset. The sampler functions that we use from the Twitter dataset keep an internal index. We could modify this so that (a sketch of steps 1-3 follows the list):

  1. the enqueue function calls get_batch before the threads:addjob()
  2. the batch is then serialized, and the batch frame replaces the idx in the arglist
  3. replace the get method with to_tensor (alternatively, override the inherited get with a Batchframe-specific get method)
  4. modify how the transform/filter work and add target_transform/input_transform; the latter two also need to be exported into the thread environment
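
A rough sketch of what steps 1 through 3 could look like in the enqueue path. get_batch, Batchframe, and to_tensor are the Dataframe names mentioned above, but their exact signatures here are my assumptions, as is the surrounding plumbing:

-- Sketch only: an enqueue that ships a pre-drawn batch instead of an index.
local function enqueue(threads, df, batchsize)
   -- 1. draw the batch in the main thread, before threads:addjob()
   local batch = df:get_batch(batchsize)

   threads:addjob(
      -- 2. the serialized batch replaces the idx in the job's arglist
      function(batch)
         -- 3. to_tensor() stands in for torchnet's get(): turn the
         --    Batchframe into input/target tensors
         local input, target = batch:to_tensor()
         return {input = input, target = target}
      end,
      -- the endcallback runs back in the main thread with the job's result
      function(sample)
         -- hand `sample` over to the engine / consumer here
      end,
      batch
   )
end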