lava-nc / lava

A Software Framework for Neuromorphic Computing
https://lava-nc.org

Using Pytorch code in dataloader causes io.dataloader.SpikeDataloader to stall #540

Open nskat opened 1 year ago

nskat commented 1 year ago

Objective of issue: Get clarity on an issue that occurs when pre-processing data with PyTorch and using the I/O processes.

Lava version:

I'm submitting a ...

Current behavior:

I perform classification on large, static matrices (size 381x651). Due to their size, I programmed a custom dataloader in which the data is imported from .mat files using scipy.io, transformed into tensors, max-pooled for dimensionality reduction, flattened, and finally rate encoded. The resulting tensor is converted back into a numpy array and fed to an io.dataloader.SpikeDataloader, then to an embedded_io.spike.PyToNxAdapter, and on to my network. I realised that the conversion from numpy array to torch tensor causes the code to stall when running on Loihi. Below is some code to reproduce the behavior.

Expected behavior:

Steps to reproduce:

Related code:

import numpy as np
import torch
from torchvision.transforms import ToTensor

from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi2HwCfg
from lava.proc import io
from lava.proc.embedded_io.spike import PyToNxAdapter
from lava.lib.dl import netx

num_examples = 10  # dummy value
ds = 64  # pooling factor
T = 10  # number of time-steps

class BlockingDummyDataset(torch.utils.data.Dataset):
    """
    Generates random data and pre-processes it using PyTorch.
    """
    def __init__(self, shape,
                 num_examples,
                 num_timesteps,
                 transform=ToTensor(),
                 ds=ds
                 ):
        super(BlockingDummyDataset, self).__init__()

        self.data = np.random.random((num_examples, ) + shape)
        self.targets = np.random.randint(0, 9, (num_examples,))
        self.transform = transform
        self.ds = ds
        self.num_timesteps = num_timesteps

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index: int):
        """
        :param: index: int
         Index
        :return: tuple: (data, target) where target is index of the target class.
        """
        data = self.data[index]
        data = self.transform(data)  # applies ToTensor(); this is the line that blocks
        # Max pooling for dimensionality reduction, flatten, repeat over
        # time-steps, then Bernoulli rate encoding
        data = torch.nn.functional.max_pool2d(data, (self.ds, self.ds))
        data = data.flatten().unsqueeze(1).repeat(1, self.num_timesteps)
        data = torch.bernoulli(data)

        return data.numpy(), self.targets[index]

test_data = BlockingDummyDataset(shape=(381, 651), num_examples=num_examples, num_timesteps=T)
dataloader = io.dataloader.SpikeDataloader(dataset=test_data, interval=1)
pytoncadapter = PyToNxAdapter(shape=((381 // ds) * (651 // ds),), interval=1)

# Network
net = netx.hdf5.Network(net_config='network.hdf5') # example network with dummy weights

# Connect everything
dataloader.out_ports.s_out.connect(pytoncadapter.inp)
pytoncadapter.out.connect(net.inp)

num_steps = T * num_examples
cfg = Loihi2HwCfg()

net.run(condition=RunSteps(num_steps=num_steps),
        run_cfg=cfg)
net.stop()
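
If the stall really is tied to calling torch inside __getitem__, a torch-free variant of the same pre-processing might sidestep it. Below is a minimal numpy-only sketch (untested on Loihi; _maxpool2d_np and rate_encode_np are hypothetical helpers I wrote for illustration, not Lava API):

import numpy as np

def _maxpool2d_np(x, ds):
    # Hypothetical numpy stand-in for torch's max_pool2d: crop to a
    # multiple of ds (matching max_pool2d's floor behaviour), then take
    # the max over each ds x ds block.
    h, w = (x.shape[0] // ds) * ds, (x.shape[1] // ds) * ds
    blocks = x[:h, :w].reshape(h // ds, ds, w // ds, ds)
    return blocks.max(axis=(1, 3))

def rate_encode_np(x, ds, num_timesteps):
    # Same pipeline as __getitem__ above, without torch: pool, flatten,
    # repeat over time-steps, then Bernoulli rate encoding.
    rates = _maxpool2d_np(x, ds).flatten()
    probs = np.repeat(rates[:, None], num_timesteps, axis=1)
    return (np.random.random(probs.shape) < probs).astype(float)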

Other information:
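
One thing I can only speculate about: Lava runs each Process in its own system process, and PyTorch's intra-op thread pool is known to misbehave when torch ops execute in a process forked from a parent that already initialized torch. If that is what happens here, forcing single-threaded torch before the run might unblock it. This is an assumption on my part, not a confirmed fix:

import os
os.environ["OMP_NUM_THREADS"] = "1"  # assumption: set before torch creates its thread pool

import torch
torch.set_num_threads(1)  # assumption: single-threaded intra-op avoids a fork-related deadlock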

Thanks for having a look at this!

PhilippPlank commented 1 year ago

Sorry for the late reply. @bamsumit Is this a bug on our side or on the PyTorch side?