IntelLabs / bayesian-torch

A library for Bayesian neural network layers and uncertainty estimation in Deep Learning extending the core of PyTorch
BSD 3-Clause "New" or "Revised" License
544 stars 73 forks source link

dnn_to_bnn() does not support Conv2d with non-square kernels #32

Closed bryanmooremd closed 1 year ago

bryanmooremd commented 1 year ago

The base model has a layer, `torch.nn.Conv2d(1,8,(32,1),stride=(16,1))`, that works correctly. After running `dnn_to_bnn(model)`, I get the following error:

RuntimeError: Calculated padded input size per channel: (256 x 8). Kernel size: (32 x 32). Kernel size can't be greater than actual input size.

The error appears to come from `dnn_to_bnn()` using the wrong kernel size when it converts the base model's Conv layers: the original layer's (32, 1) kernel is reported as (32 x 32).

msubedar commented 1 year ago

@bryanmooremd Could you please confirm you are using the latest version of the code? Can you please share the model definition and complete log?

bryanmooremd commented 1 year ago

@msubedar

Version being used is bayesian-torch==0.2.1

Model definition:

# Weighted BCE for class imbalance in spike vs non-spike
def weighted_binary_cross_entropy(output, target, epsilon, weights=None):
    """Return the (optionally class-weighted) binary cross-entropy, averaged over all elements.

    Args:
        output: predicted probabilities. They are clamped to
            ``[epsilon, 1 - epsilon]`` so both logarithms stay finite.
        target: binary labels, elementwise-combined with ``output``.
        epsilon: clamp margin used above.
        weights: optional pair ``(w_negative, w_positive)``. ``weights[1]``
            scales the ``target`` (positive) term and ``weights[0]`` scales the
            ``1 - target`` (negative) term. ``None`` means no weighting.

    Returns:
        Scalar tensor equal to ``-mean(loss)``.

    Raises:
        ValueError: if ``weights`` is given but does not have exactly two entries.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if weights is not None and len(weights) != 2:
        raise ValueError(f"weights must have exactly 2 entries, got {len(weights)}")
    output = torch.clamp(output, min=epsilon, max=(1 - epsilon))
    pos_term = target * torch.log(output)
    neg_term = (1 - target) * torch.log(1 - output)
    if weights is not None:
        pos_term = weights[1] * pos_term
        neg_term = weights[0] * neg_term
    loss = pos_term + neg_term
    return torch.neg(torch.mean(loss))

class subnet(torch.nn.Module):
    """Per-output head: batch-normalise ``[x_first_ord, x_k2]`` and project it to one score.

    ``x_first_ord`` and ``x_k2`` are concatenated along dim 1. Together they
    must provide 32 + 16 = 48 features, the size of ``bn1``/``fc1``. The
    ``fb`` argument is part of the call signature but is not used.
    """

    def __init__(self):
        super().__init__()
        n_features = 32 + 16
        # Registration order (bn1, then fc1) is kept so state_dict keys match.
        self.bn1 = torch.nn.BatchNorm1d(n_features)
        self.fc1 = torch.nn.Linear(n_features, 1)

    def forward(self, x_k2, fb, x_first_ord):
        combined = torch.cat([x_first_ord, x_k2], dim=1)
        return self.fc1(self.bn1(combined))

class shared_k2(torch.nn.Module):
    """Shared MLP mapping 16 input features to 16 output features.

    Pipeline: bn1 -> fc1 (16->32) -> ReLU -> bn2 -> fc2 (32->32) -> ReLU
    -> bn3 -> fc3 (32->16).
    """

    def __init__(self):
        super(shared_k2, self).__init__()
        # Registration order is unchanged so state_dict keys stay the same.
        self.bn1 = torch.nn.BatchNorm1d(16)
        self.bn2 = torch.nn.BatchNorm1d(32)
        self.bn3 = torch.nn.BatchNorm1d(32)
        self.fc1 = torch.nn.Linear(16, 32)
        self.fc2 = torch.nn.Linear(32, 32)
        self.fc3 = torch.nn.Linear(32, 16)
        self.activation1 = torch.nn.ReLU()
        self.activation2 = torch.nn.ReLU()

    def forward(self, x):
        x = self.bn1(x)
        x = self.bn2(self.activation1(self.fc1(x)))
        x = self.bn3(self.activation2(self.fc2(x)))
        return self.fc3(x)

class v19(torch.nn.Module):
    """Convolutional front end feeding 32 per-output probit heads.

    Shape flow, from the layer definitions:
    (batch, 1, 256, >=8) -> keep first 8 columns -> conv0 -> (batch, 8, 15, 8)
    -> conv1 -> (batch, 4, 1, 8) -> permute -> (batch, 8, 4, 1)
    -> conv2 -> (batch, 32, 1, 1) -> flatten -> x_FO (batch, 32)
    -> fc_dr -> (batch, 16) -> shared_model_k2 -> (batch, 16)
    -> heads unshared1..unshared32 -> (batch, 32) -> standard-normal CDF.

    NOTE: conv0/conv1/conv2 use non-square kernels. After ``dnn_to_bnn``
    conversion these need bayesian-torch >= 0.4.0; with 0.2.1 the converted
    conv0 fails with "Kernel size can't be greater than actual input size".
    """

    # Number of per-output heads, registered as attributes unshared1..unshared32.
    _N_HEADS = 32

    def __init__(self):
        super(v19, self).__init__()
        # (batch,1,256,8) -> (batch,8,15,8): temporal smoothing and shared basis use.
        self.conv0 = torch.nn.Conv2d(1, 8, (32, 1), stride=(16, 1))
        self.bn1 = torch.nn.BatchNorm2d(8)
        # (batch,8,15,8) -> (batch,4,1,8): global temporal filters shared across all input neurons.
        self.conv1 = torch.nn.Conv2d(8, 4, (15, 1), stride=(1, 1))
        self.bn2 = torch.nn.BatchNorm2d(4)
        # (batch,8,4,1) -> (batch,8*4,1,1); groups=8 keeps each input neuron separate.
        self.conv2 = torch.nn.Conv2d(8, (4 * 8), (4, 1), stride=(1, 1), groups=8)
        # forward() needs fc_dr, but it was never defined (AttributeError).
        # x_FO has 32 features and shared_k2 expects 16, so Linear(32, 16) is
        # the shape-consistent choice.
        # NOTE(review): confirm this matches the intended reduction layer.
        self.fc_dr = torch.nn.Linear(32, 16)
        # Define shared models for k2s, k2x
        self.shared_model_k2 = shared_k2()
        # setattr keeps the attribute names and registration order of the
        # original explicit assignments, so state_dict keys are unchanged.
        for i in range(1, self._N_HEADS + 1):
            setattr(self, f"unshared{i}", subnet())

    def forward(self, x):
        x = x.float()
        x = x[:, :, :, :8]  # keep only the first 8 input columns
        x = self.bn1(self.conv0(x))
        x = self.bn2(self.conv1(x))
        x = x.permute(0, 3, 1, 2)  # (batch,4,1,8) -> (batch,8,4,1)
        x = self.conv2(x)
        x_FO = x.reshape(x.shape[0], x.shape[1])  # (batch, 32)
        x_DR = self.fc_dr(x_FO)  # (batch, 16)
        shared_out_k2 = self.shared_model_k2(x_DR)
        # subnet.forward takes (x_k2, fb, x_first_ord). Passing only two
        # positional args raised TypeError; fb is unused there, so pass None.
        u_pred = torch.cat(
            [getattr(self, f"unshared{i}")(shared_out_k2, None, x_FO)
             for i in range(1, self._N_HEADS + 1)],
            1,
        )
        # Map each head's score to [0, 1] via the standard-normal CDF (probit link).
        norm = torch.distributions.normal.Normal(0, 1)
        p_pred = norm.cdf(u_pred)
        return p_pred

# Define Training Loop
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch and return the summed per-batch loss.

    Each batch loss is the weighted BCE of the model output (epsilon 5e-8,
    weights [1, 1]) plus ``get_kl_loss(model)``. The per-batch losses are
    summed, and the total is logged to Weights & Biases as 'Loss'.
    ``args`` is accepted for interface compatibility but is not used.
    """
    # Training mode so batchnorm (and dropout, if any) use their training behaviour.
    model.train()
    # Move the model to the same device the data is sent to. A hard-coded
    # .cuda() would ignore `device` and break CPU runs.
    model.to(device)
    loss_epoch = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        # Reset the gradients to 0 for all learnable weight parameters.
        optimizer.zero_grad()
        output = model(data)
        kl = get_kl_loss(model)
        # NOTE(review): the KL term is added unscaled. bayesian-torch's README
        # example divides it by the batch size (ce + kl / batch_size); confirm
        # the intended weighting.
        loss = weighted_binary_cross_entropy(output, target, 5e-8, weights=[1, 1]) + kl
        loss_epoch += loss.item()
        loss.backward()
        optimizer.step()
    wandb.log({'Loss': loss_epoch, 'Epoch': epoch})
    # Returned so the caller can track the best training loss.
    return loss_epoch
def test(args, model, device, test_loader, epoch):
    """Evaluate ``model`` on ``test_loader`` and return the summed validation loss.

    For every batch the loss is the weighted BCE (epsilon 5e-8, weights [1, 1])
    plus ``get_kl_loss(model)``. Batch losses are summed, not averaged. The
    total is logged to Weights & Biases as 'Validation Loss'. Gradients are
    disabled throughout, and ``args`` is accepted but unused.
    """
    model.eval()
    running_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predictions = model(inputs)
            kl_term = get_kl_loss(model)
            batch_loss = weighted_binary_cross_entropy(predictions, labels, 5e-8, weights=[1, 1]) + kl_term
            running_total += batch_loss.item()
    wandb.log({'Validation Loss': running_total, 'Epoch': epoch})
    return running_total

# Set up Weights & Biases for training visualization. wb_config holds the
# hyperparameters read by main() plus descriptive fields that are only logged.
wandb.init(project='Deep_MIMO',name="v19_K036U175_8in_32out")
wb_config = wandb.config
wb_config.learn_rate = 1e-3
wb_config.epochs = 50
wb_config.batch_size = 2000
wb_config.validation_batch_size = 2000
wb_config.dropout = 0  # logged only; the v19 model defines no dropout layers
wb_config.log_interval = 1
wb_config.Activation = 'ReLU'
# main() builds torch.optim.Adam, so log that instead of the stale 'SGD'.
wb_config.Optimizer = 'Adam'
wb_config.Num_GPUs = '2'

def main():
    """Build the Bayesian v19 model, train it for ``wb_config.epochs`` epochs, and checkpoint the best epochs.

    Relies on module-level names defined elsewhere in the script
    (my_dataset_train, my_dataset_val, set_seeds, device, wb_config,
    dnn_to_bnn, train, test). A checkpoint of ``model.module`` is saved
    whenever the training loss or the validation loss improves.
    """
    train_loader = torch.utils.data.DataLoader(my_dataset_train, batch_size=wb_config.batch_size, shuffle=True)
    # The validation loader uses its own configured batch size.
    test_loader = torch.utils.data.DataLoader(my_dataset_val, batch_size=wb_config.validation_batch_size, shuffle=True)
    set_seeds(6)
    model = v19()
    const_bnn_prior_parameters = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Reparameterization",  # Flipout or Reparameterization
        "moped_enable": False,  # True to initialize mu/sigma from the pretrained dnn weights
        "moped_delta": 0.5}
    model = nn.DataParallel(model)
    # Non-square conv kernels require bayesian-torch >= 0.4.0 here.
    dnn_to_bnn(model, const_bnn_prior_parameters)
    # dnn_to_bnn swaps in newly constructed Bayesian layers. Move the model to
    # `device` only after conversion so those layers end up there as well.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=wb_config.learn_rate)
    wandb.watch(model, log='all')
    best_loss = float('inf')
    best_val_loss = float('inf')
    # range() excludes its end, so +1 runs exactly wb_config.epochs epochs.
    for epoch in range(1, wb_config.epochs + 1):
        start = time.time()
        this_train_loss = train(wb_config, model, device, train_loader, optimizer, epoch)
        this_val_loss = test(wb_config, model, device, test_loader, epoch)
        end = time.time()
        epoch_time = np.round((end-start),1)
        if this_train_loss < best_loss:
            # Save a checkpoint of the unwrapped (non-DataParallel) model.
            torch.save(model.module.state_dict(), 'v19_K036U175_8in_32out_best_loss.h5')
        best_loss = min(best_loss, this_train_loss)
        if this_val_loss < best_val_loss:
            torch.save(model.module.state_dict(), 'v19_K036U175_8in_32out_best_loss_val.h5')
            print('Epoch =',epoch,', Saved loss =',np.round(this_train_loss,6),', Saved val loss =',np.round(this_val_loss,6), 'Epoch Time in sec =',epoch_time)
        else:
            print('Epoch =',epoch,', Loss =',np.round(this_train_loss,6),', Val loss =',np.round(this_val_loss,6), 'Epoch Time in sec =',epoch_time)
        best_val_loss = min(best_val_loss, this_val_loss)
    print('Best loss: ',np.round(best_loss,4))
    print('Best val loss: ',np.round(best_val_loss,4))


if __name__ == '__main__':
    main()

Complete log:

(bryan_env) CUDA_VISIBLE_DEVICES=0,1 python v19_K036U175.py
GPU device count =  2
GPUs being used =  Tesla V100-SXM2-32GB
wandb: Tracking run with wandb version 0.12.16
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Traceback (most recent call last):
  File "/ifshome/bmoore/v19_K036U175.py", line 348, in <module>
    main()
  File "/ifshome/bmoore/v19_K036U175.py", line 328, in main
    this_train_loss = train(wb_config, model, device, train_loader, optimizer, epoch)
  File "/ifshome/bmoore/v19_K036U175.py", line 263, in train
    output = model(data)
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/torch/_utils.py", line 457, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/ifshome/bmoore/v19_K036U175.py", line 204, in forward
    x = self.conv0(x)
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/ifshome/bmoore/.conda/envs/bryan_env/lib/python3.9/site-packages/bayesian_torch/layers/variational_layers/conv_variational.py", line 341, in forward
    out = F.conv2d(input, weight, bias, self.stride, self.padding,
RuntimeError: Calculated padded input size per channel: (256 x 8). Kernel size: (32 x 32). Kernel size can't be greater than actual input size
msubedar commented 1 year ago

@bryanmooremd Could you please check with bayesian-torch==0.4.0? Arbitrary kernel sizes are supported in this version.

bryanmooremd commented 1 year ago

@msubedar Switching from 0.2.1 to 0.4.0 resolved the issue, as you suggested. Thank you. I am not sure why `pip install bayesian-torch` installed 0.2.1 several weeks ago.