Open lostella opened 4 years ago
We could have both a parameter and a context manager, as in mxnet's Context.
We also need to extend the pickling methods to support pytorch, essentially this:
# retrieve the batch from shared memory along with metadata
success, worker_id, batch = pickle.loads(got)
needs to work for pytorch (tensors) too.
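A minimal sketch of that round-trip with numpy batches (the `got`, `success`, `worker_id`, and `batch` names follow the snippet above; the torch case would need tensors to survive the same `pickle.dumps`/`pickle.loads` round-trip):

```python
import pickle

import numpy as np

# a batch as produced by the data loader: a dict of field name -> array
batch = {"target": np.arange(6.0).reshape(2, 3)}

# what a worker process would put into shared memory: (success, worker_id, batch)
got = pickle.dumps((True, 0, batch))

# retrieve the batch from shared memory along with metadata
success, worker_id, restored = pickle.loads(got)

assert success and worker_id == 0
assert np.array_equal(restored["target"], batch["target"])
```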
> We also need to extend the pickling methods to support pytorch, essentially this:
>
> # retrieve the batch from shared memory along with metadata
> success, worker_id, batch = pickle.loads(got)
>
> needs to work for pytorch (tensors) too.
We need to be fully serde compatible, not just pickle.
Couldn't we just use type-specific strategies for this? Basically define a Batcher type, which then is implemented for different types?
> Couldn't we just use type-specific strategies for this? Basically define a Batcher type, which then is implemented for different types?
Yeah, something like that as well
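One way to sketch the idea (the `Batcher` name comes from the comment above; the `NumpyBatcher` and `batchify` shown here are hypothetical illustrations, and a torch or tensorflow implementation would supply its own stacking logic):

```python
from typing import List, Protocol

import numpy as np

class Batcher(Protocol):
    """Type-specific strategy for stacking a list of values into a batch."""
    def stack(self, elements: List) -> object: ...

class NumpyBatcher:
    # stacks arrays along a new first (batch) axis; leaves other types as lists
    def stack(self, elements):
        if isinstance(elements[0], np.ndarray):
            return np.stack(elements)
        return list(elements)

def batchify(items: List[dict], batcher: Batcher) -> dict:
    # combine a sequence of dictionaries item-wise into a single batch dict
    return {key: batcher.stack([item[key] for item in items]) for key in items[0]}

batch = batchify(
    [{"x": np.zeros(2), "id": "a"}, {"x": np.ones(2), "id": "b"}],
    NumpyBatcher(),
)
assert batch["x"].shape == (2, 2)
assert batch["id"] == ["a", "b"]
```

A torch-backed implementation would then just swap in a `Batcher` that calls `torch.stack` on lists of tensors.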
thanks for looking into this... on the pytorch side though it would be ideal to continue to use the torch DataLoader, which then returns the batches from some map- or Iterable-style Dataset. The DataLoader also supports workers etc., and other libraries like pytorch-lightning then work with this abstraction. On the tensorflow side I would need to check what tf.data expects or returns...
> thanks for looking into this... on the pytorch side though it would be ideal to continue to use torch DataLoader which then return the batches from some map or Iterable style Dataset.
Yes, it would be nice to have that too.
What I’m proposing here is to maintain gluonts’ data loader as a first step (this should really be a no-brainer). But I agree that eventually making it easy to use each framework’s dataset abstractions and data loading tools would be very nice.
The very last step of the data loader is to stack a sequence of dictionaries item-wise to get a single dictionary (a batch).
The way things are stacked depends on their type, and currently np.ndarray objects are converted into mx.nd.NDArray ones and stacked along a new first axis, see here.
However, the batchify function could somehow be parametrized by the type of array that it should produce, and yield batches containing PyTorch or TensorFlow arrays instead, if desired. This rather minimal change would enable using the datasets, data transformations, and data loader from gluonts to feed models written using frameworks other than mxnet.
One way to do that would be to pass a parameter down the data loading stack (some string?); a less invasive option could be to use a context manager.
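The context-manager option could look roughly like the following sketch (all names here are hypothetical, not gluonts API; the default stacking mirrors the current behavior of stacking arrays along a new first axis, and a torch user could pass `torch.stack` instead):

```python
import contextlib
import threading
from typing import Callable, List

import numpy as np

# thread-local holder for the current stacking function; numpy stacking by default
_local = threading.local()

def _stack_numpy(arrays: List[np.ndarray]) -> np.ndarray:
    # current gluonts behavior: stack along a new first (batch) axis
    return np.stack(arrays)

def _current_stack() -> Callable:
    return getattr(_local, "stack", _stack_numpy)

@contextlib.contextmanager
def batch_array_type(stack_fn: Callable):
    # temporarily switch the stacking function used by batchify
    previous = _current_stack()
    _local.stack = stack_fn
    try:
        yield
    finally:
        _local.stack = previous

def batchify(items: List[dict]) -> dict:
    # stack a sequence of dictionaries item-wise into a single batch dict
    stack = _current_stack()
    return {key: stack([item[key] for item in items]) for key in items[0]}

items = [{"x": np.zeros(3)}, {"x": np.ones(3)}]
assert batchify(items)["x"].shape == (2, 3)

# with torch installed, `with batch_array_type(torch.stack): ...` would make
# batchify yield torch tensors; here we demonstrate with a stand-in function
with batch_array_type(lambda arrays: tuple(arrays)):
    assert isinstance(batchify(items)["x"], tuple)

# outside the context, the default numpy stacking is restored
assert isinstance(batchify(items)["x"], np.ndarray)
```

This keeps the call sites unchanged, at the cost of a bit of implicit global state; the explicit parameter variant would instead thread the stacking function through the data loading stack.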