jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

Spiking VGG-16 CUDA Out of Memory / Multi-GPU Training #240

Open tapantabhanja opened 1 year ago

tapantabhanja commented 1 year ago

Description

The idea was to convert a VGG-16 network into its equivalent spiking version and train it on the Cats and Dogs image dataset.

The initial problem was that the model would not load on my GPU. It threw a CUDA out-of-memory error, even though my model is much smaller than the GPU memory. I could not find where in the code the problem lies. As a workaround, I thought multi-GPU training would be a good idea: the DeepSpeed strategy might split my model across multiple GPUs and so let me train it. I tried the PyTorch Lightning library, which provides an easier API for multi-GPU training, but this was also unsuccessful because the initialisation of the process groups timed out. My code and stack traces follow:

What I Did

My code:

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms 
from torch.nn import functional as F
from torchvision import datasets
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import ToTensor
from torchvision.models import vgg16, VGG16_Weights
import matplotlib.pyplot as plt
import glob
import os
import random
from PIL import Image
import time

import snntorch as snn
from snntorch import surrogate
from snntorch import spikegen
from snntorch import functional as SF
from snntorch import utils

import pytorch_custom_utils

import lightning as L 
from lightning.fabric.strategies.deepspeed import DeepSpeedStrategy

start_time = time.time()

num_steps = 10

# Creating the Dataset Class
class Cat_n_DogDataset(Dataset):

    def __init__(self, data, transform = None) -> None:

        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        img_path = self.data[index]

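        # .convert("RGB") ensures 3 channels even for grayscale or CMYK files
        img = Image.open(img_path).convert("RGB")

        # Loading the label from the parent folder name
        if img_path.split('/')[-2] == 'Dog':
            label = 0

        elif img_path.split('/')[-2] == 'Cat':
            label = 1

        else:
            raise ValueError(f"Cannot infer a label from {img_path}")

        if self.transform:
            img = self.transform(img)

        return img, label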

# Creating the Spiking VGG16.
class Spiking_VGG16(nn.Module):
    """
    Implementation of Spiking VGG16. 
    """
    def __init__(self) -> None:
        super(Spiking_VGG16, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )    # Look at how to change the input channels for neuromorphic images

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )

        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer5 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )

        self.layer6 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )

        self.layer7 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer8 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        ) 

        self.layer9 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )

        self.layer10 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer11 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )

        self.layer12 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )

        self.layer13 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fclayer14 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features=25088, out_features=4096),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )

        self.fclayer15 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=4096),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True)
        )

        self.fclayer16 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=2),
            snn.Leaky(beta=0.5, spike_grad=surrogate.atan(), init_hidden=True, output=True)
        )

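    def forward(self, x):

        # The same static image is fed at every time step. Only the output of the
        # final step is returned: fclayer16 (output=True) yields a (spikes,
        # membrane potential) tuple, and the outputs of earlier steps are
        # overwritten. The autograd graph of every step is still kept, because
        # each step's membrane potential depends on the previous one.
        for step in range(num_steps):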
            out = self.layer1(x)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = self.layer5(out)
            out = self.layer6(out)
            out = self.layer7(out)
            out = self.layer8(out)
            out = self.layer9(out)
            out = self.layer10(out)
            out = self.layer11(out)
            out = self.layer12(out)
            out = self.layer13(out)
            out = out.reshape(out.size(0), -1)
            out = self.fclayer14(out)
            out = self.fclayer15(out)
            out = self.fclayer16(out)

        return out

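# Helper that collects per-step records (not called below). Note that
# Spiking_VGG16.forward already loops over num_steps internally, so calling
# it num_steps times here would simulate num_steps * num_steps steps.
def forward_pass(net, num_steps, data):
    mem_rec = []
    spk_rec = []
    utils.reset(net)  # resets hidden states for all LIF neurons in net

    for step in range(num_steps):
        spk_out, mem_out = net(data)
        spk_rec.append(spk_out)
        mem_rec.append(mem_out)

    return torch.stack(spk_rec), torch.stack(mem_rec)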

# Creating a Cats & Dogs Dataset
cats = []
dogs = []

train_data = []
test_data = []

# Number of Epochs
num_epochs = 10

path = './data/Cat_n_Dog/PetImages/'

# For Cats
for root, dirs, files in os.walk(path + 'Cat/'):

    for file in files:
        cats.append(os.path.join(root, file))

cats_shuffle = random.sample(cats, len(cats))

# For Dogs
for root, dirs, files in os.walk(path + 'Dog/'):

    for file in files:
        dogs.append(os.path.join(root, file))

dogs_shuffle = random.sample(dogs, len(dogs))

# Creating the train dataset for cats and dogs:
# For Dogs 
dog_train = dogs_shuffle[:int(0.7*len(dogs_shuffle))]
#dog_train = dogs_shuffle[:int(0.01*len(dogs_shuffle))]

# For Cats
cat_train = cats_shuffle[:int(0.7*len(cats_shuffle))]
#cat_train = cats_shuffle[:int(0.01*len(cats_shuffle))]

# Creating the test dataset for cats and dogs:
# For Dogs
dog_test = dogs_shuffle[int(0.7*len(dogs_shuffle)):]
#dog_test = dogs_shuffle[int(0.999*len(dogs_shuffle)):]
# For Cats
cat_test = cats_shuffle[int(0.7*len(cats_shuffle)):]
#cat_test = cats_shuffle[int(0.999*len(cats_shuffle)):]

print("Length of the Train Dataset: ", len(dog_train) + len(cat_train))
print("Length of the Test Dataset: ", len(dog_test) + len(cat_test))

# Finally, merging cat_train & dog_train to create the train dataset
train_data = cat_train + dog_train

# Merging cat_test & dog_test to create the test dataset
test_data = cat_test + dog_test

# Random Sampling the Train Dataset
train_data_shuffle = random.sample(train_data, len(train_data))

# Random Sampling the Test Dataset
test_data_shuffle = random.sample(test_data, len(test_data))

# Creating the Data Transforms
# Normalising the images with the ImageNet mean and standard deviation.
imgnet_norm = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_transforms = transforms.Compose([
#    transforms.ToPILImage(),
    transforms.Resize(224), 
    transforms.CenterCrop(224), 
    transforms.ToTensor(), 
#    transforms.Normalize(*imgnet_norm)
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

test_transforms = transforms.Compose([
#    transforms.ToPILImage(),
    transforms.Resize(224), 
    transforms.CenterCrop(224), 
    transforms.ToTensor(), 
#    transforms.Normalize(*imgnet_norm)
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

cats_n_dogs_train = Cat_n_DogDataset(data=train_data_shuffle, transform=train_transforms)
cats_n_dogs_test = Cat_n_DogDataset(data=test_data_shuffle, transform=test_transforms)

# Creating the Train and Test DataLoader. 
train_data_loader = DataLoader(cats_n_dogs_train, batch_size=8, num_workers=0, pin_memory=False, shuffle=True)
test_data_loader = DataLoader(cats_n_dogs_test, batch_size=8, num_workers=0, pin_memory=False, shuffle=True)

# Adding PyTorch Lightning to the code.
deepspeed_strategy = DeepSpeedStrategy(process_group_backend="nccl")
fabric = L.Fabric(accelerator="cuda", devices=4, strategy=deepspeed_strategy)
fabric.launch()

# Initialising GPU
# device = torch.device(device="cuda" if torch.cuda.is_available() else 'cpu')
# print(device)

# Initialising the Model
n_model = Spiking_VGG16()

# Defining Loss 
loss = nn.CrossEntropyLoss()

# Defining Optimizer
optimizer = optim.Adam(n_model.parameters(), lr=5e-4, betas=(0.9, 0.999))

# Pushing model to GPU
# n_model = n_model.to(device)

# Adding further lines for fabric
n_model, optimizer = fabric.setup(n_model, optimizer)
train_data_loader = fabric.setup_dataloaders(train_data_loader)
test_data_loader = fabric.setup_dataloaders(test_data_loader)

loss_hist = []
test_loss_hist = []
counter = 0
# avg_epoch_train_accuracy = 0

n_model.train()

# Outer Training Loop
for epoch in range(num_epochs):

    avg_epoch_train_accuracy = 0
    avg_epoch_train_loss = 0

    mini_batch_accuracy = 0
    mini_batch_loss = []
    mini_batch_total = 0

    # Training Loop
    for train_batch in train_data_loader:

        train_images, train_labels = train_batch

        # Converting Training Images to Spiking Images
        # spike_train_images = spikegen.rate(train_images, num_steps=100)

        # Pushing Spiking Training Images and Labels to GPUs
        # spike_train_images = spike_train_images.to(device=device)
        #train_images = train_images.to(device)
        #train_labels = train_labels.to(device)

        # Resetting the hidden states of all LIF neurons (init_hidden=True) before each batch
        utils.reset(n_model)

        print("Test ",train_images.element_size())

        spk_record, mem_pot_record = n_model(train_images)

        print("Shape of the spiking records: ", spk_record.shape)
        print("Shape of the Membrane Potential records: ", mem_pot_record.shape)
        print("Shape of the train labels: ", train_labels.shape)
        print("Shape of the Membrane Potential Records per step: ", mem_pot_record[0].shape)

        # Initialise the loss and sum over time
        # loss_val = torch.zeros((1), device=device)

        # for step in range(num_steps):
        #    loss_val += loss(mem_pot_record[step], train_labels)
        loss_val = loss(mem_pot_record, train_labels)

        # Gradient Calculation + Weight Update
        optimizer.zero_grad()
        # loss_val.backward()

        # Fabric replaces loss_val.backward(); it takes the loss tensor, not the loss module
        fabric.backward(loss_val)

        optimizer.step()

        # Store loss history
        loss_hist.append(loss_val.item())

        # Mini batch Accuracy
        mini_batch_accuracy += SF.accuracy_rate(spk_record, train_labels)*spk_record.size(1)

        # Mini Batch Loss
        mini_batch_loss.append(loss_val.item())  # store a Python float, not the tensor

        mini_batch_total += spk_record.size(1)

        torch.cuda.empty_cache()

    avg_epoch_train_accuracy = mini_batch_accuracy/mini_batch_total
    avg_epoch_train_loss = sum(mini_batch_loss)/len(mini_batch_loss)  # mean of the per-batch losses

    print("Epoch = {}, Training Loss = {}, Training Accuracy = {}".format(epoch, avg_epoch_train_loss, avg_epoch_train_accuracy))

print("Finished Training")

# Testing Block
n_model.eval()  # disable dropout and use the BatchNorm running statistics
with torch.no_grad():

    avg_test_accuracy = 0

    test_mini_batch_accuracy = 0
    test_mini_batch_total = 0

    # Testing Loop
    for test_batch in test_data_loader:

        test_images, test_labels = test_batch

        # Pushing Test Images and Labels to GPUs
        # test_images = test_images.to(device)
        # test_labels = test_labels.to(device)

        test_spk_record, test_mem_record = n_model(test_images)

        # Test Mini-Batch Accuracy
        test_mini_batch_accuracy += SF.accuracy_rate(test_spk_record, test_labels)*test_spk_record.size(1)

        test_mini_batch_total += test_spk_record.size(1)

    avg_test_accuracy = test_mini_batch_accuracy/test_mini_batch_total

    print("Test Accuracy: ", avg_test_accuracy)

torch.cuda.empty_cache()

end_time = time.time()

print("Time Required for the execution: ", (end_time-start_time)/60)
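
`Cat_n_DogDataset.__getitem__` infers the label from `img_path.split('/')[-2]` and opens every file that `os.walk` returns. A minimal sketch of more defensive path handling, assuming the `PetImages/Cat` and `PetImages/Dog` layout used above: it keeps only common image extensions (so stray non-image files, such as a Windows `Thumbs.db`, would be skipped) and uses `pathlib` so the label lookup does not depend on the path separator. The helper names `label_from_path` and `collect_samples` are hypothetical. It does not detect corrupted image files, which would still raise an error when PIL decodes them.

```python
from pathlib import Path
from typing import List, Tuple

LABELS = {"Dog": 0, "Cat": 1}  # same mapping as Cat_n_DogDataset
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def label_from_path(img_path: str) -> int:
    """Label from the parent folder name, e.g. .../PetImages/Dog/1.jpg -> 0."""
    folder = Path(img_path).parent.name
    if folder not in LABELS:
        raise ValueError(f"Cannot infer a label from {img_path!r}")
    return LABELS[folder]


def collect_samples(paths: List[str]) -> List[Tuple[str, int]]:
    """Keep image files only (case-insensitive suffix) and pair each with its label."""
    return [(p, label_from_path(p)) for p in paths
            if Path(p).suffix.lower() in IMAGE_SUFFIXES]


samples = collect_samples([
    "./data/Cat_n_Dog/PetImages/Dog/1.jpg",
    "./data/Cat_n_Dog/PetImages/Cat/2.JPG",
    "./data/Cat_n_Dog/PetImages/Cat/Thumbs.db",
])
print(samples)
```

The returned `(path, label)` pairs could replace the plain path lists before splitting, so `__getitem__` no longer has to parse the path itself.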

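As written, `forward` loops over `num_steps` but returns only the final step's `(spk, mem)` tuple, whereas the commented-out loss loop and `SF.accuracy_rate` expect records shaped `[num_steps, batch_size, num_classes]`. The pure-Python sketch below illustrates, on nested lists instead of tensors, what those two computations do with such records: cross-entropy on the output membrane potential summed over time steps, and a rate-coded prediction taken as the class with the most spikes. The function names are hypothetical; this is a sketch of the idea, not snnTorch's implementation.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: log(sum(exp(logits))) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def loss_over_time(mem_rec, targets):
    """mem_rec[t][b][c]: sum over time steps of the batch-mean cross-entropy."""
    total = 0.0
    for mem_t in mem_rec:
        batch_loss = [cross_entropy(mem_t[b], targets[b]) for b in range(len(targets))]
        total += sum(batch_loss) / len(batch_loss)  # CrossEntropyLoss averages over the batch
    return total


def accuracy_rate(spk_rec, targets):
    """Predict the class with the highest spike count summed over time."""
    num_classes = len(spk_rec[0][0])
    correct = 0
    for b, target in enumerate(targets):
        counts = [sum(spk_t[b][c] for spk_t in spk_rec) for c in range(num_classes)]
        prediction = counts.index(max(counts))
        correct += prediction == target
    return correct / len(targets)


# Toy records: num_steps=3, batch_size=2, num_classes=2 (Dog=0, Cat=1)
spk_rec = [[[1, 0], [0, 1]],
           [[1, 0], [0, 1]],
           [[0, 1], [0, 1]]]
mem_rec = [[[2.0, 0.0], [0.0, 2.0]]] * 3
targets = [0, 1]

print(loss_over_time(mem_rec, targets))
print(accuracy_rate(spk_rec, targets))  # 1.0
```

In the toy records, sample 0 spikes twice for class 0 and once for class 1, so its rate-coded prediction is still class 0. With real tensors, the same two steps would operate on the stacked outputs of every time step.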
My stack trace for the CUDA out-of-memory error:

  File "/work/bhanja/example_training/training_neuro_cats_dogs.py", line 386, in <module>
    spk_record, mem_pot_record = n_model(train_images)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/work/bhanja/example_training/training_neuro_cats_dogs.py", line 183, in forward
    out = self.layer2(out)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 193, in forward
    self.reset = self.mem_reset(self.mem)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/snntorch/_neurons/neurons.py", line 107, in mem_reset
    reset = self.spike_grad(mem_shift).clone().detach()
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/snntorch/surrogate.py", line 210, in inner
    return ATan.apply(x, alpha)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/snntorch/surrogate.py", line 189, in forward
    out = (input_ > 0).float()
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 98.00 MiB (GPU 0; 31.74 GiB total capacity; 30.76 GiB already allocated; 3.12 MiB free; 31.30 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
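
The size of the model and the memory needed to train it are different quantities here. Training through `num_steps` with backpropagation through time keeps the activations of every layer for every time step until the backward pass, so activation memory grows with batch size × time steps × image size, while the parameter memory stays fixed. ZeRO sharding in DeepSpeed (stages 1 to 3) partitions optimizer states, gradients and parameters, which are the small terms in this estimate. The 98.00 MiB request in the trace is exactly one float32 tensor of shape 8 × 64 × 224 × 224, the size of a `layer1`/`layer2` activation for a batch of 8. The pure-Python sketch below makes a back-of-envelope estimate under stated assumptions: 224 × 224 inputs, float32, `batch_size=8`, `num_steps=10`, and only one stored tensor per convolution output. It is therefore a lower bound, since autograd usually saves several tensors per layer (for example the convolution output for BatchNorm and the membrane potential for the surrogate gradient).

```python
# Back-of-envelope memory estimate for the Spiking_VGG16 above (a sketch).
# Assumptions: 224x224 RGB input, float32, batch_size=8, num_steps=10.

# (in_channels, out_channels, followed_by_maxpool) for layer1 ... layer13
CONV_CFG = [
    (3, 64, False), (64, 64, True),
    (64, 128, False), (128, 128, True),
    (128, 256, False), (256, 256, False), (256, 256, True),
    (256, 512, False), (512, 512, False), (512, 512, True),
    (512, 512, False), (512, 512, False), (512, 512, True),
]
# (in_features, out_features) for fclayer14 ... fclayer16
FC_CFG = [(25088, 4096), (4096, 4096), (4096, 2)]
BYTES_PER_FLOAT = 4
MIB, GIB = 2**20, 2**30


def count_parameters():
    """Trainable parameters: conv kernels and biases, BatchNorm affine terms, linear layers."""
    total = 0
    for c_in, c_out, _ in CONV_CFG:
        total += c_in * c_out * 3 * 3 + c_out  # 3x3 kernel + bias
        total += 2 * c_out                     # BatchNorm2d weight + bias
    for f_in, f_out in FC_CFG:
        total += f_in * f_out + f_out
    return total


def conv_outputs_per_image(size=224):
    """Number of conv-output elements for ONE image at ONE time step."""
    total = 0
    for _, c_out, pooled in CONV_CFG:
        total += c_out * size * size
        if pooled:
            size //= 2  # MaxPool2d(kernel_size=2, stride=2)
    return total


batch_size, num_steps = 8, 10
params = count_parameters()
acts = conv_outputs_per_image()

param_mib = params * BYTES_PER_FLOAT / MIB
layer2_mib = batch_size * 64 * 224 * 224 * BYTES_PER_FLOAT / MIB
# Lower bound: only ONE stored float32 tensor per conv output, per image, per step
act_gib = acts * batch_size * num_steps * BYTES_PER_FLOAT / GIB

print(f"parameters: {params:,} (~{param_mib:.0f} MiB as float32)")
print(f"one layer2 activation for the batch: {layer2_mib:.2f} MiB")
print(f"stored activations over all steps: >= ~{act_gib:.1f} GiB")
```

Under these assumptions the parameters take about 512 MiB, while even the one-tensor-per-layer lower bound for stored activations is about 4 GiB; a real run that saves several tensors per layer needs a multiple of that, which is consistent with the 30.76 GiB already allocated. The activation term scales linearly with `num_steps`, the batch size and the number of pixels.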

My stack trace for the timed-out initialisation of the process groups:

  File "/work/bhanja/example_training/training_neuro_cats_dogs.py", line 336, in <module>
    fabric.launch()
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/lightning/fabric/fabric.py", line 664, in launch
    return self._strategy.launcher.launch(function, *args, **kwargs)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/lightning/fabric/strategies/launchers/subprocess_script.py", line 90, in launch
    return function(*args, **kwargs)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/lightning/fabric/fabric.py", line 749, in _run_with_setup
    self._strategy.setup_environment()
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/lightning/fabric/strategies/ddp.py", line 113, in setup_environment
    self._setup_distributed()
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/lightning/fabric/strategies/deepspeed.py", line 576, in _setup_distributed
    self._init_deepspeed_distributed()
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/lightning/fabric/strategies/deepspeed.py", line 594, in _init_deepspeed_distributed
    deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 670, in init_distributed
    cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/deepspeed/comm/torch.py", line 116, in __init__
    self.init_process_group(backend, timeout, init_method, rank, world_size)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/deepspeed/comm/torch.py", line 142, in init_process_group
    torch.distributed.init_process_group(backend,
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 932, in init_process_group
    _store_based_barrier(rank, store, timeout)
  File "/home/bhanja/.conda/envs/spikingjelly/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 469, in _store_based_barrier
    raise RuntimeError(
RuntimeError: Timed out initializing process group in store based barrier on rank: 0, for key: store_based_barrier_key:1 (world_size=8, worker_count=1, timeout=0:30:00)
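
In this trace, rank 0 waited 30 minutes for `world_size=8` processes, but only one arrived (`worker_count=1`), although `L.Fabric` was created with `devices=4`. Such a mismatch can mean that the process count was taken from the environment (for example, a job scheduler's settings) rather than from the script. Below is a minimal diagnostic sketch. `check_world_size` is a hypothetical helper, not a Lightning or DeepSpeed API. It assumes that the expected number of processes is `devices * num_nodes`, and that on a SLURM cluster `SLURM_NTASKS` may be used as the world size. `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT` are the variables read by `torch.distributed`'s `env://` initialisation.

```python
import os
from typing import List, Mapping

# Variables that can determine the process count of a torch.distributed job.
# SLURM_NTASKS only exists on SLURM clusters (an assumption about the setup).
COUNT_VARS = ("WORLD_SIZE", "SLURM_NTASKS")
INFO_VARS = ("RANK", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT")


def check_world_size(env: Mapping[str, str], devices: int, num_nodes: int = 1) -> List[str]:
    """Hypothetical helper: list process-count variables that disagree with devices * num_nodes."""
    expected = devices * num_nodes
    problems = []
    for var in COUNT_VARS:
        value = env.get(var)
        if value is not None and int(value) != expected:
            problems.append(f"{var}={value}, but devices * num_nodes = {expected}")
    return problems


if __name__ == "__main__":
    # Run this in the same job step that starts the training script
    for var in COUNT_VARS + INFO_VARS:
        print(f"{var}={os.environ.get(var, '<unset>')}")
    for problem in check_world_size(os.environ, devices=4):
        print("Mismatch:", problem)
```

With `SLURM_NTASKS=8` and `devices=4`, for instance, it reports one mismatch. In that case the job's task count and the `devices`/`num_nodes` arguments would need to agree before the barrier can complete.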

What am I doing wrong? I could not find a solution for this.

ahenkes1 commented 10 months ago

@tapantabhanja , have you resolved the problem?

tapantabhanja commented 10 months ago

@ahenkes1 Unfortunately not. After this failed on my laptop, I made a naive attempt to solve the CUDA out-of-memory error by running the code on an HPC cluster, whose GPUs are much more powerful and have more memory. Unfortunately, I got the same error there too, which made me realise that the problem stems from somewhere else. I still could not figure it out.

I would be very grateful for some help.