IBM / aihwkit

IBM Analog Hardware Acceleration Kit
https://aihwkit.readthedocs.io
Apache License 2.0

Backward pass not ideal when using InferenceRPUConfig #628

Open HeatPhoenix opened 5 months ago

HeatPhoenix commented 5 months ago

Description

As I understand InferenceRPUConfig, training with this RPUConfig should give a noisy forward pass but an ideal backward pass. If so, setting rpu_config.backward.is_perfect = True should have no effect, because by the definition of InferenceRPUConfig the backward pass is already perfect by default.

Instead, training with and without rpu_config.backward.is_perfect = True produces different results.

Inference plots: [image: backward.is_perfect = True] vs. [image: default InferenceRPUConfig settings]

Loss development: [image: backward.is_perfect = True] vs. [image: default settings]

How to reproduce

I am training a simple regression network like so:

from torch import nn

neurons = 128
model = nn.Sequential(
    nn.Linear(6, neurons),
    nn.Softplus(),
    nn.Linear(neurons, neurons),
    nn.Softplus(),
    nn.Linear(neurons, neurons),
    nn.Softplus(),
    nn.Linear(neurons, 3),
)

With the following RPUConfig:

from aihwkit.inference.compensation.drift import GlobalDriftCompensation
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel
from aihwkit.simulator.configs.configs import InferenceRPUConfig
from aihwkit.simulator.parameters.enums import WeightNoiseType

# Configure the inference / hardware-aware training tile.
rpu_config = InferenceRPUConfig()
rpu_config.forward.out_res = -1.0  # Turn off (output) ADC discretization.
rpu_config.forward.w_noise_type = WeightNoiseType.ADDITIVE_CONSTANT
rpu_config.forward.w_noise = 0.02  # Short-term w-noise.
# Inference noise model.
rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)
# Drift compensation.
rpu_config.drift_compensation = GlobalDriftCompensation()
rpu_config.backward.is_perfect = True  # Or comment out for the default behavior.

Training runs for 100 epochs with AnalogAdam to produce the plots above. The network is confirmed to work when fully digital, and also when is_perfect is set for both the forward and the backward pass.

Expected behavior

I would expect backward.is_perfect to have no effect in inference-only setups. Instead, it has a very significant effect. Is this writing noise? Reading noise from getting the activations? The documentation and the communication on GitHub state things like "the backward pass and update is thought to be perfect but noise is injected in the forward pass only; see InferenceRPUConfig".
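
To make the expectation concrete, here is a minimal, library-free sketch of a single linear layer with a noisy forward pass and a switchable backward pass. It is a conceptual illustration only and not aihwkit's implementation: the helper names (matvec, add_noise, forward, backward), the Gaussian weight noise, and the toy weights are all assumptions made for this example. With is_perfect=True the backward pass uses the stored weights and is deterministic; with is_perfect=False it draws fresh weight noise on every call.

```python
import random


def matvec(w, x):
    """Plain matrix-vector product for a list-of-rows matrix."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def transpose(w):
    return [list(col) for col in zip(*w)]


def add_noise(w, std, rng):
    """Return a copy of w with additive Gaussian noise (assumed noise model)."""
    return [[w_ij + rng.gauss(0.0, std) for w_ij in row] for row in w]


def forward(w, x, w_noise, rng):
    """Noisy forward pass: y = (W + noise) x."""
    return matvec(add_noise(w, w_noise, rng), x)


def backward(w, grad_y, w_noise, rng, is_perfect):
    """Input gradient W^T grad_y; noise is added only if the pass is not perfect."""
    w_used = w if is_perfect else add_noise(w, w_noise, rng)
    return matvec(transpose(w_used), grad_y)


rng = random.Random(0)
w = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.4]]  # hypothetical 2x3 weight matrix
x = [1.0, 2.0, 3.0]
grad_y = [1.0, -1.0]

y_noisy = forward(w, x, 0.02, rng)  # differs from call to call
exact = matvec(transpose(w), grad_y)
perfect = [backward(w, grad_y, 0.02, rng, is_perfect=True) for _ in range(3)]
noisy = [backward(w, grad_y, 0.02, rng, is_perfect=False) for _ in range(3)]

print("exact backward   :", exact)
print("perfect backward :", perfect)
print("noisy backward   :", noisy)
```

If the backward pass of an inference-only configuration were already ideal, toggling the flag in such a model would leave the gradients unchanged, which is the behavior expected in this report.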

Other information

maljoras commented 5 months ago

That's indeed surprising. You are right, it should not have any effect. Thanks for raising the issue.

jubueche commented 1 month ago

Hi. I ran this sample script:

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from aihwkit.inference.compensation.drift import GlobalDriftCompensation
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel
from aihwkit.nn.conversion import convert_to_analog
from aihwkit.optim.analog_optimizer import AnalogSGD
from aihwkit.simulator.configs.configs import InferenceRPUConfig
from aihwkit.simulator.parameters.enums import WeightNoiseType

def instantiate_model():
    """Build a small MNIST CNN (digital; converted to analog later)."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)
        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
            # Return raw logits; nn.CrossEntropyLoss applies log-softmax internally.
            return x
    return Net()

def gen_rpu_config(is_perfect):
    rpu_config = InferenceRPUConfig()
    rpu_config.forward.out_res = -1.0  # Turn off (output) ADC discretization.
    rpu_config.forward.w_noise_type = WeightNoiseType.ADDITIVE_CONSTANT
    rpu_config.forward.w_noise = 0.02  # Short-term w-noise.
    # Inference noise model.
    rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)
    # Drift compensation
    rpu_config.drift_compensation = GlobalDriftCompensation()
    rpu_config.backward.is_perfect = is_perfect
    return rpu_config

def get_data():
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST(
        'data', train=True, download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
    return train_loader

if __name__ == '__main__':
    # Fall back to CPU when no GPU is available.
    cuda = torch.cuda.is_available()
    model = instantiate_model()
    analog_models = [
        convert_to_analog(model, gen_rpu_config(is_perfect=True)),
        convert_to_analog(model, gen_rpu_config(is_perfect=False)),
    ]
    optimizers = []
    for analog_model in analog_models:
        if cuda:
            analog_model.cuda()
        analog_model.train()
        optimizer = AnalogSGD(analog_model.parameters(), lr=0.001)
        optimizers.append(optimizer)
    n_iters = -1  # -1 disables the early break, so one full epoch is run.
    logging_interval = 25
    criterion = nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(get_data()):
        if cuda:
            data, target = data.cuda(), target.cuda()
        gradients_mean = []
        gradients_std = []
        losses = []
        for analog_model, optimizer in zip(analog_models, optimizers):
            optimizer.zero_grad()
            model_output = analog_model(data)
            loss = criterion(model_output, target)
            loss.backward()
            optimizer.step()
            # Collect all parameter gradients into one flat tensor.
            gradients = torch.tensor([])
            for p in analog_model.parameters():
                if p.grad is not None:
                    p_grad = p.grad.detach().cpu().flatten()
                    gradients = torch.cat([gradients, p_grad], dim=-1)
            gradients_mean.append(gradients.mean())
            gradients_std.append(gradients.std())
            losses.append(loss.item())
        if batch_idx % logging_interval == 0:
            print('Iteration {} mean: '.format(batch_idx), gradients_mean, ', std: ', gradients_std, 'loss: ', losses)
        if batch_idx == n_iters - 1:
            break

and this is the log I get:

Iteration 0 mean:  [tensor(0.0003), tensor(0.0002)] , std:  [tensor(0.0017), tensor(0.0023)] loss:  [2.440281629562378, 2.4453179836273193]
Iteration 25 mean:  [tensor(0.0003), tensor(0.0002)] , std:  [tensor(0.0015), tensor(0.0024)] loss:  [2.4614646434783936, 2.4668071269989014]
Iteration 50 mean:  [tensor(0.0002), tensor(0.0002)] , std:  [tensor(0.0015), tensor(0.0023)] loss:  [2.465376615524292, 2.409703016281128]
Iteration 75 mean:  [tensor(0.0003), tensor(0.0002)] , std:  [tensor(0.0014), tensor(0.0022)] loss:  [2.433210611343384, 2.42989444732666]
Iteration 100 mean:  [tensor(0.0002), tensor(0.0001)] , std:  [tensor(0.0014), tensor(0.0022)] loss:  [2.3848493099212646, 2.417853355407715]
Iteration 125 mean:  [tensor(1.3142e-05), tensor(0.0001)] , std:  [tensor(0.0014), tensor(0.0021)] loss:  [2.3443965911865234, 2.314314365386963]
Iteration 150 mean:  [tensor(0.0001), tensor(8.9972e-05)] , std:  [tensor(0.0015), tensor(0.0023)] loss:  [2.37026309967041, 2.2686352729797363]
Iteration 175 mean:  [tensor(1.2524e-05), tensor(5.3390e-05)] , std:  [tensor(0.0014), tensor(0.0021)] loss:  [2.365387201309204, 2.37935733795166]
Iteration 200 mean:  [tensor(8.2563e-05), tensor(-4.6383e-06)] , std:  [tensor(0.0014), tensor(0.0023)] loss:  [2.332214117050171, 2.349396228790283]
Iteration 225 mean:  [tensor(0.0001), tensor(0.0002)] , std:  [tensor(0.0015), tensor(0.0023)] loss:  [2.3589439392089844, 2.356438398361206]

In each list, the first entry is for is_perfect=True and the second for is_perfect=False. I don't see a clear difference.
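
One way to put a number on this comparison is to average the logged losses per configuration. The sketch below copies the loss pairs from the log above; the variable names and the choice of a plain mean are assumptions for illustration, and ten logged batches from a single run are not a statistical test.

```python
from statistics import mean

# (is_perfect=True, is_perfect=False) loss pairs copied from the log above,
# one pair per logged iteration (0, 25, ..., 225).
logged_losses = [
    (2.440281629562378, 2.4453179836273193),
    (2.4614646434783936, 2.4668071269989014),
    (2.465376615524292, 2.409703016281128),
    (2.433210611343384, 2.42989444732666),
    (2.3848493099212646, 2.417853355407715),
    (2.3443965911865234, 2.314314365386963),
    (2.37026309967041, 2.2686352729797363),
    (2.365387201309204, 2.37935733795166),
    (2.332214117050171, 2.349396228790283),
    (2.3589439392089844, 2.356438398361206),
]

mean_perfect = mean(a for a, _ in logged_losses)
mean_default = mean(b for _, b in logged_losses)
mean_gap = mean_perfect - mean_default

# Count the logged batches in which is_perfect=True had the lower loss.
perfect_wins = sum(a < b for a, b in logged_losses)

print(f"mean loss, is_perfect=True : {mean_perfect:.4f}")
print(f"mean loss, is_perfect=False: {mean_default:.4f}")
print(f"gap (True - False)         : {mean_gap:+.4f}")
print(f"is_perfect=True lower in {perfect_wins} of {len(logged_losses)} batches")
```

The same pattern could be applied to the logged gradient statistics, or to losses collected over several seeds, before concluding that the two settings behave alike.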

Are you sure you are setting the rpu_config? If you pass None to convert_to_analog, a different default tile class is used.

PabloCarmona commented 3 weeks ago

Hi @HeatPhoenix! Did you have the chance to try out and look at the script @jubueche suggested?

Give us some feedback when you can, so we can help further if any problems arise. Thank you!