jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

"Expected all tensors to be on the same device" When Creating LIF Neurons #316

Closed kt-13 closed 2 months ago

kt-13 commented 2 months ago

Description

I get an error saying that all tensors must be on the same device when I try to run a new model on a GPU. The code I am using is below. It seems similar to the issue at https://github.com/jeshraghian/snntorch/issues/225. Manually setting the device for each Leaky object, as in the commented-out lines, fixes the issue.
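The manual workaround can be wrapped in a small helper instead of being repeated for every layer. The sketch below is an illustration under assumptions, not snnTorch API: `move_neuron_state` and `ToyNeuron` are hypothetical names, and the helper assumes each neuron keeps its membrane potential in a tensor attribute called `mem`, as in the commented-out lines.

```python
import torch
import torch.nn as nn


def move_neuron_state(model: nn.Module, device, attr: str = "mem") -> int:
    """Move every tensor stored under `attr` on any submodule to `device`.

    Hypothetical helper: it assumes neuron state lives in a tensor
    attribute (named `mem` by default) that `model.to(device)` missed.
    Returns the number of tensors that were reassigned.
    """
    device = torch.device(device)
    moved = 0
    for module in model.modules():
        state = getattr(module, attr, None)
        # Skip parameters: they are already handled by model.to(device).
        if isinstance(state, torch.Tensor) and not isinstance(state, nn.Parameter):
            setattr(module, attr, state.to(device))
            moved += 1
    return moved


class ToyNeuron(nn.Module):
    """Stand-in for a stateful neuron; not an snnTorch class."""

    def __init__(self):
        super().__init__()
        self.mem = torch.zeros(1)  # plain attribute, not a registered buffer


toy = nn.Sequential(nn.Linear(4, 4), ToyNeuron(), ToyNeuron())
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
toy.to(target)
n_moved = move_neuron_state(toy, target)
```

Calling such a helper once after `model.to(device)` would stand in for the per-layer lines, provided the state attribute really is named `mem` in the installed version.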

What I Did

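import time

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

import snntorch as snn
from snntorch import functional as SF


class VGG16_SNN(nn.Module):
    def __init__(self, beta, num_classes=10):
        super(VGG16_SNN, self).__init__()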

        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(3)
        self.leaky1 = snn.Leaky(beta=beta)
        #self.leaky1.mem = self.leaky1.mem.to(torch.device("cuda"))

        self.conv2 = nn.Conv2d(3, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        self.leaky2 = snn.Leaky(beta=beta)
        #self.leaky2.mem = self.leaky2.mem.to(torch.device("cuda"))
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(128)
        self.leaky3 = snn.Leaky(beta=beta)
        #self.leaky3.mem = self.leaky3.mem.to(torch.device("cuda"))

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3)
        self.bn4 = nn.BatchNorm2d(128)
        self.leaky4 = snn.Leaky(beta=beta)
        #self.leaky4.mem = self.leaky4.mem.to(torch.device("cuda"))

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3)
        self.bn5 = nn.BatchNorm2d(256)
        self.leaky5 = snn.Leaky(beta=beta)
        #self.leaky5.mem = self.leaky5.mem.to(torch.device("cuda"))

        self.conv6 = nn.Conv2d(256, 256, kernel_size=3)
        self.bn6 = nn.BatchNorm2d(256)
        self.leaky6 = snn.Leaky(beta=beta)
        #self.leaky6.mem = self.leaky6.mem.to(torch.device("cuda"))

        self.conv7 = nn.Conv2d(256, 256, kernel_size=3)
        self.bn7 = nn.BatchNorm2d(256)
        self.leaky7 = snn.Leaky(beta=beta)
        #self.leaky7.mem = self.leaky7.mem.to(torch.device("cuda"))

        self.lin1 = nn.Linear(256*7*7, 4096)
        self.bn1d1 = nn.BatchNorm1d(4096)
        self.leaky14 = snn.Leaky(beta=beta)
        #self.leaky14.mem = self.leaky14.mem.to(torch.device("cuda"))
        self.do = nn.Dropout(p = 0.2)
        self.lin2 = nn.Linear(4096, 2048)
        self.bn1d2 = nn.BatchNorm1d(2048)
        self.leaky15 = snn.Leaky(beta=beta)
        #self.leaky15.mem = self.leaky15.mem.to(torch.device("cuda"))
        self.do2 = nn.Dropout(p = 0.2)
        self.lin3 = nn.Linear(2048, num_classes)

        self.leaky16 = snn.Leaky(beta=beta)
        #self.leaky16.mem = self.leaky16.mem.to(torch.device("cuda"))

    def forward(self, x, time_steps, ratio, epoch):

        self.leaky1.reset_mem()
        self.leaky2.reset_mem()
        self.leaky3.reset_mem()
        self.leaky4.reset_mem()
        self.leaky5.reset_mem()
        self.leaky6.reset_mem()
        self.leaky7.reset_mem()

        self.leaky14.reset_mem()
        self.leaky15.reset_mem()
        self.leaky16.reset_mem()

        spk_recording = []
        xOrig = x

        for t in range(time_steps):
            start1 = time.time()
            b1 = self.bn1(self.conv1(xOrig))
            l1 = self.leaky1(b1)[0]#, m1

            b2 = self.bn2(self.conv2(l1))
            l2 = self.leaky2(b2)[0]#, m2

            b3 = self.bn3(self.conv3(l2))
            l3 = self.leaky3(b3)[0] #, m3

            b4 = self.bn4(self.conv4(l3))
            l4 = self.leaky4(b4)[0]#, m4

            b5 = self.bn5(self.conv5(l4))
            l5 = self.leaky5(b5)[0]#, m5

            b6 = self.bn6(self.conv6(l5))
            l6 = self.leaky6(b6)[0]#, m6

            b7 = self.bn7(self.conv7(l6))
            p1 = self.pool1(b7)
            l7 = self.leaky7(p1)[0]#, m7

            f1 = torch.flatten(l7, 1)

            fc1 = self.lin1(f1)

            l14 = self.leaky14(fc1)[0]#, m14

            fc2 = self.lin2(l14)

            l15 = self.leaky15(fc2)[0]#, m15

            fc3 = self.lin3(l15)
            l16 = self.leaky16(fc3)[0]#, m16
            spk_recording.append(l16)
        return torch.stack(spk_recording)


batch_size = 64
num_epochs = 10
time_steps = 5
beta = 0.75
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(p=0.3),
    torchvision.transforms.RandomVerticalFlip(p=0.3),
    torchvision.transforms.ToTensor(),
])
train = torchvision.datasets.FashionMNIST(root='/content', transform=transform, download=True)

#train_ds = torch.utils.data.Subset(train, torch.arange(0, 10000))
test = torchvision.datasets.FashionMNIST(root='/content', train=False, transform=torchvision.transforms.ToTensor(), download=True)
#test_ds = torch.utils.data.Subset(train, torch.arange(1000, 2000))
data_loader = DataLoader(train, batch_size=batch_size, shuffle=True)#, num_workers = 8
test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)

spk_rec_final = []
loss_hist = []
acc_hist = []

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = VGG16_SNN(beta, 10).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=2.47e-4, momentum=0.9) #lr=2.5e-4 if no weight decay
loss_fn =  SF.ce_rate_loss()

for epoch in range(num_epochs):

  print(f"Starting epoch number {epoch}")
  counter = 0
  torch.cuda.empty_cache()

  for bitmap, target in iter(data_loader):

      bitmap = bitmap.to(device)
      target = target.to(device)
      model.train()

      spk_rec = model(bitmap, time_steps, ratio, epoch)
      spk_rec_final.append(spk_rec)

      loss_val = loss_fn(spk_rec, target)

      optimizer.zero_grad()
      loss_val.backward()

      optimizer.step()
      loss_hist.append(loss_val.item())

  if epoch == num_epochs -1:
      print('calcing accuracy')
      with torch.no_grad():
        model.eval()
        acc_train = batch_accuracy(data_loader, model, time_steps, device, ratio)
        acc_test = batch_accuracy(test_loader, model, time_steps, device, ratio)
        print(f"Iteration {epoch}, Train Acc: {acc_train * 100:.2f}%\n")
        print(f"Iteration {epoch}, Test Acc: {acc_test * 100:.2f}%\n")
        acc_hist.append(acc_test.item())
        break
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 54
     50 model.train()
     51 #print(torch.unsqueeze(bitmap, dim=1).shape)
---> 54 spk_rec = model(bitmap, time_steps, ratio, epoch)
     55 #print('time to get spks', time.time() - start)
     56 spk_rec_final.append(spk_rec)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[4], line 152, in VGG16_SNN.forward(self, x, time_steps, ratio, epoch)
    149 b1 = self.bn1(self.conv1(xOrig))
    150 #print('time to convolve and normalize', time.time() -  start1)
    151 #print(len(torch.where(x == 1)[0]))
--> 152 l1 = self.leaky1(b1)[0]#, m1
    153 #print(l1.shape)
    154
    155 
    156 #print('time to run through first block', time.time() -  start1)
    157 #start = time.time()
    158 
    159 #print(len(torch.where(x == 1)[0]))

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.10/dist-packages/snntorch/_neurons/leaky.py:208, in Leaky.forward(self, input_, mem)
    205 if not self.mem.shape == input_.shape:
    206     self.mem = torch.zeros_like(input_, device=self.mem.device)
--> 208 self.reset = self.mem_reset(self.mem)
    209 self.mem = self.state_function(input_)
    211 if self.state_quant:

File /usr/local/lib/python3.10/dist-packages/snntorch/_neurons/neurons.py:105, in SpikingNeuron.mem_reset(self, mem)
    102 def mem_reset(self, mem):
    103     """Generates detached reset signal if mem > threshold.
    104     Returns reset."""
--> 105     mem_shift = mem - self.threshold
    106     reset = self.spike_grad(mem_shift).clone().detach()
    108     return reset

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
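
In the traceback, the membrane tensor is reallocated with `device=self.mem.device`, so it keeps whatever device it started on, and the subtraction in `mem_reset` then appears to mix CPU and CUDA tensors. A minimal sketch of a device-safe alternative follows. `ToyLeaky` is a hypothetical toy module, not snnTorch's implementation: it allocates the state from the input so the state inherits the input's device, and it omits the surrogate gradient, so it only illustrates the forward pass.

```python
import torch
import torch.nn as nn


class ToyLeaky(nn.Module):
    """Toy leaky integrate-and-fire layer; an illustration, not snnTorch code."""

    def __init__(self, beta: float, threshold: float = 1.0):
        super().__init__()
        # Registered buffers follow the module when .to(device) is called.
        self.register_buffer("beta", torch.tensor(beta))
        self.register_buffer("threshold", torch.tensor(threshold))
        self.mem = None  # allocated lazily on the first forward pass

    def reset_mem(self):
        self.mem = None

    def forward(self, input_):
        if self.mem is None or self.mem.shape != input_.shape:
            # zeros_like inherits device and dtype from input_.
            self.mem = torch.zeros_like(input_)
        # Subtractive reset driven by the previous step's membrane.
        reset = (self.mem > self.threshold).to(input_.dtype).detach()
        self.mem = self.beta * self.mem + input_ - reset * self.threshold
        # Hard threshold only; no surrogate gradient in this sketch.
        spk = (self.mem > self.threshold).to(input_.dtype)
        return spk, self.mem


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = ToyLeaky(beta=0.75).to(device)
x = torch.full((2, 3), 0.6, device=device)
spk1, mem1 = layer(x)  # membrane reaches 0.6, below threshold
spk2, mem2 = layer(x)  # membrane reaches about 1.05, above threshold
```

Because the state is created from `input_`, it lands on the same device as the data regardless of where the module was constructed.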
xiziqiao commented 2 months ago

I got the same issue here! I think it is a library bug. My solution is to downgrade to 0.8.0.

SSBakh07 commented 2 months ago

Apparently, this issue has come up before (#225) and a workaround was described there, but I was able to temporarily fix it by downgrading to version 0.8.1.

pip install snntorch==0.8.1

gekkom commented 2 months ago

I will take a look at this. In the meantime, you can also work around it with:

torch.set_default_device("cuda")

morenzoe commented 2 months ago

I still get this error when running the training loop without population coding in Advanced Tutorials: Population Coding. Setting the default device to cuda did not work, but downgrading to 0.8.1 did. I guess it is because the snntorch.backprop module was deprecated.
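
For context on the default-device workaround: `torch.set_default_device` (available from PyTorch 2.0) changes where factory calls such as `torch.zeros` place new tensors when no `device` argument is given. It does not move tensors that already exist, and an explicit `device=` still takes precedence, which may be one reason it does not help in every setup. A small sketch, falling back to the CPU when no GPU is present:

```python
import torch

before = torch.zeros(2)  # created before the default is changed

target = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(target)

after = torch.zeros(2)                   # no device given: uses the new default
explicit = torch.zeros(2, device="cpu")  # an explicit device still wins

torch.set_default_device("cpu")  # restore the usual default
```

On a GPU machine, `before` and `explicit` stay on the CPU while `after` is placed on the GPU, so any state created before the call still needs to be moved by hand.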

jeshraghian commented 2 months ago

Have you tried installing snntorch from the source rather than pip?

morenzoe commented 2 months ago

I am trying to do it in Colab now. However, another error comes up, ModuleNotFoundError: No module named 'nir', even though the module was there when I checked with !pip show. Does the setup.py in snnTorch only install the module locally in the snnTorch folder path? Sorry for the off-topic question; any help would be much appreciated!

jeshraghian commented 2 months ago

Ah, I ran into the same error, but it was fixed when I restarted my runtime... In any case, I'll update the PyPI package today or tomorrow. That'll hopefully fix everything.

morenzoe commented 2 months ago

I am finally able to run both of the tutorials in Colab by installing and importing nir and nirtorch first, before installing snntorch from the source. Nevertheless, updating the PyPI package will be a great help. Thank you!