jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License
1.31k stars, 219 forks

snntorch multi GPU training issue #154

Closed rkdgmlqja closed 1 year ago

rkdgmlqja commented 1 year ago

Fri Dec  2 11:16:53 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    On   | 00000000:1B:00.0 Off |                  Off |
| 30%   30C    P8    29W / 300W |      1MiB / 48682MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:1C:00.0 Off |                  Off |
| 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA RTX A6000    On   | 00000000:1D:00.0 Off |                  Off |
| 30%   32C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA RTX A6000    On   | 00000000:1E:00.0 Off |                  Off |
| 30%   31C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA RTX A6000    On   | 00000000:3D:00.0 Off |                  Off |
| 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA RTX A6000    On   | 00000000:3F:00.0 Off |                  Off |
| 30%   29C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA RTX A6000    On   | 00000000:40:00.0 Off |                  Off |
| 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA RTX A6000    On   | 00000000:41:00.0 Off |                  Off |
| 30%   30C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

### Description
I'm trying to train on NMNIST with snntorch using multiple GPUs. Since snntorch is built on top of PyTorch, I thought DataParallel from torch.nn should work.
Here's the whole code.

```python
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import snntorch as snn
import tonic
import tonic.transforms as transforms  # the torchvision transforms import was shadowed by this one
import torch
import torch.nn as nn
import torch.nn.init
from snntorch import functional as SF
from snntorch import surrogate
from snntorch import spikeplot as splt
from snntorch import utils
from torch.utils.data import DataLoader, random_split
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sensor_size = tonic.datasets.NMNIST.sensor_size

# Denoise removes isolated, one-off events
# time_window
frame_transform = transforms.ToFrame(sensor_size=sensor_size, time_window=1)

frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000),
                                      transforms.ToFrame(sensor_size=sensor_size,
                                                         time_window=50000)
                                     ])

trainset = tonic.datasets.NMNIST(save_to='/home/hubo1024/PycharmProjects/snntorch/data/NMNIST', transform=frame_transform, train=True)
testset = tonic.datasets.NMNIST(save_to='/home/hubo1024/PycharmProjects/snntorch/data/NMNIST', transform=frame_transform, train=False)

# seed fix
torch.manual_seed(777)

# seed fix if gpu is available
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

#batch_size = 100

batch_size = 32
dataset_size = len(trainset)
train_size = int(dataset_size * 0.9)
validation_size = int(dataset_size * 0.1)

trainset, valset = random_split(trainset, [train_size, validation_size])
print(len(valset))
print(len(trainset))
trainloader = DataLoader(trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(), shuffle=True)
valloader = DataLoader(valset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(), shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors())

spike_grad = surrogate.fast_sigmoid(slope=75)
beta = 0.5

class CNN(torch.nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        self.keep_prob = 0.5
        self.layer1 = torch.nn.Sequential(
            nn.Conv2d(2, 12, 5),
            nn.MaxPool2d(2),
            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
        )

        self.layer2 = torch.nn.Sequential(
            nn.Conv2d(12, 32, 5),
            nn.MaxPool2d(2),
            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
        )

        self.layer4 = torch.nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 5 * 5, 10),
            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
        )

    def forward(self, data):
        spk_rec = []
        layer1_rec = []
        layer2_rec = []
        utils.reset(self.layer1)  # resets hidden states for all LIF neurons in net
        utils.reset(self.layer2)
        utils.reset(self.layer4)

        for step in range(data.size(1)):  # data.size(1) = number of time steps (batch-first input)
            input_torch = data[:, step, :, :, :]
            input_torch = input_torch.cuda()
            #print(input_torch)
            out = self.layer1(input_torch)
            #out1 = out

            out = self.layer2(out)
            #out2 = out
            out, mem = self.layer4(out)
            #out = self.layer4(out)

            spk_rec.append(out)

            #layer1_rec.append(out1)
            #layer2_rec.append(out2)

        return torch.stack(spk_rec)#, torch.stack(layer1_rec), torch.stack(layer2_rec)

model = CNN().to(device)
device_ids = [0, 1] #your GPU index
model = torch.nn.DataParallel(model, device_ids=device_ids)
#model = nn.DataParallel(model).to(device)
optimizer = torch.optim.NAdam(model.parameters(), lr=0.005,betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
#model = nn.DataParallel(model)

total_batch = len(trainloader)
print('총 배치의 수 : {}'.format(total_batch))  # prints "Total number of batches"
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
num_epochs = 15
loss_hist = []
acc_hist = []
v_acc_hist = []
t_spk_rec_sum = []
start = time.time()
val_cnt = 0
v_acc_sum= 0
avg_loss = 0
index = 0
#################################################

for epoch in range(num_epochs):
    torch.save(model.state_dict(), '/home/hubo1024/PycharmProjects/snntorch/model_pt/Radam_15epoch-50000.pt')
    for i, (data, targets) in enumerate(iter(trainloader)):
        data = data.cuda()
        targets = targets.cuda()
        model.train()

        spk_rec = model(data)

        #print(spk_rec.shape)
        loss_val = loss_fn(spk_rec, targets)
        avg_loss += loss_val.item()
        optimizer.zero_grad()

        loss_val.backward()

        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())
        val_cnt = val_cnt+1
        #del loss_val

        if val_cnt == len(trainloader)/2-1:
            val_cnt=0

            for ii, (v_data, v_targets) in enumerate(iter(valloader)):
                v_data = v_data.to(device)
                v_targets = v_targets.to(device)

                v_spk_rec = model(v_data)
                #
                # print(t_spk_rec.shape)
                v_acc = SF.accuracy_rate(v_spk_rec, v_targets)
                del v_spk_rec
                if ii == 0:
                    v_acc_sum = v_acc
                    cnt = 1

                else:
                    v_acc_sum += v_acc
                    cnt += 1
                #del v_acc

            plt.plot(acc_hist)
            plt.plot(v_acc_hist)
            plt.legend(['train accuracy', 'validation accuracy'])
            plt.title("Train, Validation Accuracy-Radam 15epoch-50000")
            plt.xlabel("Iteration")
            plt.ylabel("Accuracy")
            # plt.show()
            plt.savefig('Radam_15epoch-50000.png')
            plt.clf()
            v_acc_sum = v_acc_sum/cnt

            # avg_loss = avg_loss / (len(trainloader) / 2)
            # print('average loss while half epoch', avg_loss)
            # if avg_loss <= 0.5:
            #     index = 1
            #     break
            # else:
            #     avg_loss = 0
            #     index = 0

        print('Radam-15epoch-50000')
        print("time :", time.time() - start,"sec")
        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        v_acc_hist.append(v_acc_sum)
        print(f"Train Accuracy: {acc * 100:.2f}%")
        print(f"Validation Accuracy: {v_acc_sum * 100:.2f}%\n")

    #     if index == 1:
    #         break
    # if index == 1:
    #     break
# No training happens here, so use torch.no_grad()
'''
with torch.no_grad():
    X_test = mnist_test.test_data.view(len(mnist_test), 1, 28, 28).float().to(device)
    Y_test = mnist_test.test_labels.to(device)

    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())
'''
```

And here's the error:

(snn_torch) hubo1024@neuro:~/PycharmProjects/snntorch$ CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python gpu_6_run.py
6000
54000
총 배치의 수 : 13500
Traceback (most recent call last):
  File "/home/hubo1024/PycharmProjects/snntorch/gpu_6_run.py", line 146, in <module>
    spk_rec = model(data)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hubo1024/PycharmProjects/snntorch/gpu_6_run.py", line 102, in forward
    out = self.layerconv1(input_torch)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 162, in forward
    self.mem = self.state_fn(input_)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 201, in _build_state_function_hidden
    self._base_state_function_hidden(input_) - self.reset * self.threshold
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 195, in _base_state_function_hidden
    base_fn = self.beta.clamp(0, 1) * self.mem + input_
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/_tensor.py", line 1121, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I reran this code after removing the snn.Leaky layers from the CNN and it ran fine. (Of course the loss doesn't converge and the accuracy was 0%, but it still runs.) So I assume the snn.Leaky layer is the cause of this error. I think changing

ridgerchu commented 1 year ago

Hi,

I noticed you use this line for distributed training: nn.DataParallel(model).to(device). I recommend trying this version instead:

device_ids = [0, 1] #your GPU index
model = torch.nn.DataParallel(model, device_ids=device_ids)

rkdgmlqja commented 1 year ago

Hi, thank you for your help.

I tried your

device_ids = [0, 1] #your GPU index
model = torch.nn.DataParallel(model, device_ids=device_ids)

code, and the same error appeared.

Traceback (most recent call last):
  File "/home/hubo1024/PycharmProjects/snntorch/15epoch_50k.py", line 149, in <module>
    spk_rec = model(data)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hubo1024/PycharmProjects/snntorch/15epoch_50k.py", line 98, in forward
    out = self.layer1(input_torch)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 162, in forward
    self.mem = self.state_fn(input_)
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 201, in _build_state_function_hidden
    self._base_state_function_hidden(input_) - self.reset * self.threshold
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 195, in _base_state_function_hidden
    base_fn = self.beta.clamp(0, 1) * self.mem + input_
  File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/_tensor.py", line 1121, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Process finished with exit code 1

Also, I updated the environment info above, just in case.

Thank you again.

ridgerchu commented 1 year ago

Hi,

I debugged the Leaky class. You are right: the value of the self.mem variable of the Leaky class is altered unexpectedly when DataParallel is used, which causes this error. However, I have not found the reason yet. As a temporary workaround, I recommend changing the source code of the Leaky class slightly. The relevant code is at line 161 of snntorch/_neurons/leaky.py:

...
        if self.init_hidden:
            self._leaky_forward_cases(mem)
            self.reset = self.mem_reset(self.mem)
            self.mem = self.state_fn(input_)
...

You can replace it with this version:

...
        if self.init_hidden:
            self._leaky_forward_cases(mem)
            self.reset = self.mem_reset(self.mem)
            self.mem = self._build_state_function_hidden(input_)
...

This works correctly on my dual-GPU workstation.
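
The thread does not pin down the root cause. One plausible mechanism, offered here as an assumption and not a confirmed diagnosis: DataParallel builds each replica as a shallow copy of the module, so a bound method cached as an instance attribute (such as state_fn) keeps pointing at the original module and reads the original's state instead of the replica's. The sketch below reproduces that pattern in plain Python. The `Neuron` class and its method names are hypothetical stand-ins, not snntorch's actual code, and `copy.copy` only mimics the shallow replication.

```python
import copy


class Neuron:
    """Toy stand-in for a stateful layer; not snntorch's Leaky class."""

    def __init__(self):
        self.mem = "state-of-original"
        # Bound method cached as an instance attribute at construction time.
        self.state_fn = self._update

    def _update(self, input_):
        # `self` is whichever object this method is bound to.
        return f"{self.mem} + {input_}"

    def forward_cached(self, input_):
        # Calls through the cached attribute (the pattern before the patch).
        return self.state_fn(input_)

    def forward_direct(self, input_):
        # Looks the method up on the current object (the pattern after the patch).
        return self._update(input_)


original = Neuron()
# Shallow copy: the attribute dict is copied, but the cached bound method
# stored in it still refers to `original`.
replica = copy.copy(original)
replica.mem = "state-of-replica"

print(replica.forward_cached("x"))  # state-of-original + x
print(replica.forward_direct("x"))  # state-of-replica + x
```

If this is what happens, it would explain why calling `self._build_state_function_hidden(input_)` directly avoids the device mismatch: the method is then resolved on the replica, so it uses the replica's own `mem` on the replica's device.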

ridgerchu commented 1 year ago

Hi, I've pushed a PR to fix this issue. Once it is merged into the master branch, you can clone the repository and the problem will be solved! #156

jeshraghian commented 1 year ago

Made the same fix for other neurons too. #161