fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Gradient Accumulation Training Problem #495

Closed: tahafkh closed this issue 4 months ago

tahafkh commented 5 months ago

SpikingJelly version

0.0.0.0.15

Description

Hi! I was trying to train my SNN model with the gradient accumulation technique because I am limited by memory. However, during training I got an error when the backward method was called:

File "/home/tahaf/Documents/Lip_Reading_SNN/utils.py", line 301, in train
  loss.backward()
File "/home/tahaf/.local/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
  torch.autograd.backward(
File "/home/tahaf/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'StackBackward0' returned nan values in its 7th output.
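As far as I know, an error of the form `Function 'StackBackward0' returned nan values in its 7th output` is what PyTorch raises when anomaly detection (`torch.autograd.set_detect_anomaly(True)`) is enabled and a backward function produces NaN. `StackBackward0` is the backward of `torch.stack`. Backward NaNs often trace back to non-finite values already present in the forward pass, so one way to check that case is to attach forward hooks that record which submodules emit NaN/Inf. The helper `register_nonfinite_hooks`, the toy model, and the deliberately corrupted weight below are hypothetical stand-ins, not part of SpikingJelly or the original code.

```python
import torch
import torch.nn as nn


def register_nonfinite_hooks(model: nn.Module):
    """Hypothetical helper: attach forward hooks that record submodules
    whose tensor output contains NaN or Inf.

    Returns (hook handles, list of offending module names in call order).
    """
    offenders = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # This sketch only inspects plain tensor outputs.
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                offenders.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module (its name is the empty string)
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles, offenders


# Toy demonstration: corrupt one weight of the first Linear layer.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
with torch.no_grad():
    model[0].weight[0, 0] = float("nan")

handles, offenders = register_nonfinite_hooks(model)
model(torch.randn(5, 4))
for h in handles:
    h.remove()  # detach the hooks once the diagnosis is done

# The first entry is the module where non-finite values first appeared.
print(offenders[0])
```

In a real run, the first name in `offenders` points at the layer to inspect; later entries are usually just downstream propagation.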

I'm following this link (the second method) and this link to implement the technique. Here is my training code:

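import torch
import torch.nn.functional as F
from tqdm import tqdm
from spikingjelly.activation_based import functional

# Train for one epoch with gradient accumulation
def train(model, device, train_loader, optimizer, num_labels=70, scheduler=None):
    model.train()
    accumulation_steps = 8
    train_loss = 0.0
    correct = 0
    cpt = 0
    optimizer.zero_grad()
    for i, (data, target) in enumerate(tqdm(train_loader, leave=False)):
        cpt += 1
        # Swap the first two dims (presumably [N, T, ...] -> [T, N, ...] for multi-step modules)
        data, target = data.transpose(0, 1).float().to(device), target.to(device)

        output = model(data)
        # [T, N, C] -> [N, T, C] -> average over time steps -> [N, C]
        output = output.transpose(0, 1).mean(1)
        # cross_entropy with one-hot (probability) targets requires PyTorch >= 1.10
        label_onehot = F.one_hot(target, num_labels).float()
        loss = F.cross_entropy(output, label_onehot)

        # Count correct predictions without a Python loop over the batch
        correct += (output.argmax(dim=1) == target).sum().item()
        train_loss += loss.item()

        # Scale out of place so each micro-batch contributes 1/accumulation_steps
        (loss / accumulation_steps).backward()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # Clear neuron states before the next micro-batch
        functional.reset_net(model)

    # Apply leftover gradients if len(train_loader) is not a multiple of accumulation_steps
    if cpt % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    if scheduler is not None:
        scheduler.step()

    accuracy = 100.0 * correct / len(train_loader.dataset)
    train_loss /= cpt

    return train_loss, accuracy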
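The accumulation pattern in the loop above (divide each micro-batch loss by `accumulation_steps`, call `backward()` every step, call `optimizer.step()` every `accumulation_steps` steps) can be sanity-checked in isolation. A minimal sketch, assuming a plain `nn.Linear` stand-in model and random data instead of the SNN and dataset from the issue: with a mean-reduced loss and equal-sized micro-batches, 8 micro-batches of 4 should give the same gradient as one batch of 32. Layers whose behaviour depends on the batch, such as BatchNorm, compute statistics per micro-batch and will not match exactly, and stateful spiking neurons still need `functional.reset_net` between micro-batches as in the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the real model and data.
model = nn.Linear(10, 3)
inputs = torch.randn(32, 10)
targets = torch.randint(0, 3, (32,))
accumulation_steps = 8
micro_batch = inputs.shape[0] // accumulation_steps  # 4 samples per step

# Reference: a single backward pass over the full batch of 32.
model.zero_grad()
F.cross_entropy(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulation: 8 micro-batches of 4, each loss scaled by 1/8.
model.zero_grad()
for step in range(accumulation_steps):
    sl = slice(step * micro_batch, (step + 1) * micro_batch)
    loss = F.cross_entropy(model(inputs[sl]), targets[sl])
    # backward() adds into .grad, so gradients sum across micro-batches.
    (loss / accumulation_steps).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```

The equality holds because the mean over 32 samples equals the average of eight means over 4 samples each; if this check passes for a given setup, NaNs are unlikely to come from the accumulation logic itself.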

I'm using a batch size of 4 and 8 accumulation steps, i.e. an effective batch size of 32. Any ideas about the source of this error? Or is there a ready-to-use recipe you've used in your own projects?

Thanks for your help!

fangwei123456 commented 5 months ago

Can you check whether `torch.nn.functional.cross_entropy(output, label_onehot)` is getting any NaN values?
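One way to do that check is to validate the logits and the loss right before `backward()`. A minimal sketch, where `checked_loss` is a hypothetical helper (not a SpikingJelly or PyTorch API) and the random logits stand in for the model output; it assumes PyTorch >= 1.10 so that `cross_entropy` accepts the one-hot float targets used in the issue.

```python
import torch
import torch.nn.functional as F


def checked_loss(output: torch.Tensor, label_onehot: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cross-entropy that fails fast on NaN/Inf."""
    if not torch.isfinite(output).all():
        bad = (~torch.isfinite(output)).sum().item()
        raise ValueError(f"model output contains {bad} non-finite value(s)")
    loss = F.cross_entropy(output, label_onehot)
    if not torch.isfinite(loss):
        raise ValueError(f"loss is non-finite: {loss.item()}")
    return loss


# Stand-in logits for a batch of 4 with 70 classes (num_labels=70 in the issue).
good = torch.randn(4, 70)
target = F.one_hot(torch.tensor([0, 1, 2, 3]), 70).float()
loss = checked_loss(good, target)

# Inject a single NaN to show the early failure.
bad = good.clone()
bad[0, 0] = float("nan")
msg = ""
try:
    checked_loss(bad, target)
except ValueError as err:
    msg = str(err)
print(msg)  # model output contains 1 non-finite value(s)
```

Raising before `backward()` separates "the forward pass already produced NaN" from "NaN appeared only during backpropagation".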

tahafkh commented 4 months ago

Hi, I checked the input of the loss function, and that does seem to be the case. However, I then made a new run without gradient accumulation, and that run crashed as well. This made me suspect that the error was memory-related, and after reducing the number of parameters in the network I was able to train my model with gradient accumulation without any problems. Thanks!