TyXe-BDL / TyXe

MIT License

ValueError: Error while computing log_prob at site 'likelihood.data' #11

Closed BenCretois closed 2 years ago

BenCretois commented 2 years ago

Hi,

First, I want to say that the work you've done is clearly amazing.

I have been playing a bit with TyXe and am trying to convert the FC layer of a pretrained ResNet into a probabilistic layer.

I have been implementing the code shown in this paper, but without success. When I use a homoskedastic Gaussian as the likelihood function, I get this error:

ValueError: Error while computing log_prob at site 'likelihood.data':
Value is not broadcastable with batch_shape+event_shape: torch.Size([128]) vs torch.Size([128, 2]).

...

Sample Sites:                      
              net.fc.weight dist             |   2 512
                           value             |   2 512
                        log_prob             |        
                net.fc.bias dist             |   2    
                           value             |   2    
                        log_prob             |        
            likelihood.data dist             | 128   2
                           value             | 128    

My code is very minimalistic:

import torch
import torch.nn as nn
import albumentations as A

import pyro
from pyro.distributions import Normal

from torch.utils.data import DataLoader

from TyXe.tyxe.priors import IIDPrior
from TyXe.tyxe.likelihoods import HomoskedasticGaussian
from TyXe.tyxe import VariationalBNN

from utils.loader import ImgLoader

import glob

transform = A.Compose(
    [
        A.SmallestMaxSize(max_size=260),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
        A.RandomCrop(height=224, width=224),
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# Prepare the dataset
data_path = "/Data/train"
list_images = glob.glob(data_path + "/*.jpg")

trainset = ImgLoader(list_images, transform=transform)

trainLoader = DataLoader(trainset,
                        batch_size = 128, 
                        shuffle=True)

# Load pre-trained model ResNet and add FC layers on top
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True # Fix bug for Pytorch 1.9
NN = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
NN.to('cpu')

# Freeze the layers
for param in NN.parameters():
    param.requires_grad = False

NN.fc = nn.Linear(512,2)

ll_prior = IIDPrior(Normal(0, 1), expose_all=False, expose_modules=[NN.fc])
likelihood = HomoskedasticGaussian(len(list_images), event_dim=2, scale=1)
lr_guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal
bnn = VariationalBNN(NN, ll_prior, likelihood, lr_guide)
optim = pyro.optim.Adam({"lr": 1e-3})

# fit the model
bnn.fit(trainLoader, optim, 5)

Thanks for your help already!

BenCretois commented 2 years ago

Hi,

The problem is solved: I was using a Gaussian likelihood instead of a Categorical one for my data. The model can now be trained and reaches a sensible accuracy (~96% after only 3 epochs of training).
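To illustrate the resolution above, here is a minimal sketch in plain PyTorch (not TyXe) of why the shapes clash. It assumes random stand-in logits for the output of `NN.fc` and integer class labels of shape `[128]`; the names `logits` and `labels` are made up for this example. In TyXe terms the change would presumably be to swap `HomoskedasticGaussian` for the Categorical likelihood in `tyxe.likelihoods`, but the exact call used is not shown in this thread.

```python
import torch
from torch.distributions import Categorical, Independent, Normal

torch.manual_seed(0)

batch_size, num_classes = 128, 2
logits = torch.randn(batch_size, num_classes)          # hypothetical stand-in for the NN.fc output
labels = torch.randint(0, num_classes, (batch_size,))  # integer class labels, shape [128]

# Gaussian with one event dimension: batch_shape [128], event_shape [2].
# It expects a target of shape [128, 2], not a vector of class indices.
gaussian = Independent(Normal(logits, 1.0, validate_args=True), 1)
try:
    gaussian.log_prob(labels.float())
    gaussian_error = None
except ValueError as err:
    gaussian_error = str(err)
print("Gaussian:", gaussian_error)

# Categorical over the same logits: batch_shape [128], event_shape [].
# A vector of class indices with shape [128] is exactly what it expects.
categorical = Categorical(logits=logits)
log_prob = categorical.log_prob(labels)
print("Categorical:", tuple(log_prob.shape))
```

The Gaussian treats the two network outputs as a 2-dimensional regression target, whereas the Categorical treats them as scores for two classes, which matches a dataset with one class label per image.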

karalets commented 2 years ago

Thank you for letting us know. I’m happy to hear you resolved the issue and find TyXe useful!


BenCretois commented 2 years ago

It's really useful indeed! I first attempted to convert my DL models with Pyro but was unsuccessful. TyXe makes it way easier. I made a notebook about the analysis I was doing, using the classic cats vs. dogs dataset. I could share it if you think that would be useful for others who would like to use TyXe :)

hpplyt commented 2 years ago

Glad you managed to work things out. We should make it explicit in the docs which likelihoods correspond to which loss functions. I'm realizing that it's probably not very clear at the moment, so thanks for opening the issue.
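As a rough illustration of that correspondence, the sketch below uses plain PyTorch (not TyXe's likelihood classes) to check two standard identities: the negative log-likelihood of a Categorical equals the cross-entropy loss, and the negative log-likelihood of a fixed-scale Gaussian equals a scaled squared error plus a constant. The tensors are random placeholders, not data from this issue.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Normal

torch.manual_seed(0)

# Classification: Categorical log-likelihood vs. cross-entropy loss.
logits = torch.randn(8, 3)
labels = torch.randint(0, 3, (8,))
nll_categorical = -Categorical(logits=logits).log_prob(labels).mean()
cross_entropy = F.cross_entropy(logits, labels)

# Regression: fixed-scale Gaussian log-likelihood vs. mean squared error.
preds = torch.randn(8, 1)
targets = torch.randn(8, 1)
scale = 1.0
nll_gaussian = -Normal(preds, scale).log_prob(targets).mean()
# 0.5 * MSE / scale^2 plus terms that do not depend on the network output
mse_based = (
    0.5 * F.mse_loss(preds, targets) / scale**2
    + math.log(scale)
    + 0.5 * math.log(2 * math.pi)
)

print(torch.allclose(nll_categorical, cross_entropy, atol=1e-5))
print(torch.allclose(nll_gaussian, mse_based, atol=1e-5))
```

So a network that would normally be trained with cross-entropy calls for a Categorical likelihood, and one trained with mean squared error calls for a Gaussian likelihood.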

If you could share a link to your notebook that would be great! I'm sure others would be grateful for seeing more example use cases. I'd be happy to start a collection of contributed examples at the bottom of the readme and just put the link there for now.

BenCretois commented 2 years ago

Here is the cleaned notebook: https://gitlab.com/nina-data/deepexperiments/-/blob/main/notebooks/bayesianfy.ipynb. I still need to add a bit more text but most of it is already there. It's currently in a GitLab repo but it's no problem for me to place it on GitHub! And yes, it would be great to have a section of contributed examples, that would be very helpful!

DengyingFu commented 1 year ago

Error while computing log_prob at site 'obs': Value is not broadcastable with batch_shape+event_shape: torch.Size([32, 3, 128, 128]) vs torch.Size([32, 3, 128]).

I have run into a similar error. Do you know how to correct the code, for example how to change the size here? By the way, I can't seem to open your notebook.
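The code behind this second error is not included in the thread, so the cause cannot be pinned down here. One way to see which dimension disagrees is to line up the two shapes from the right, as broadcasting does. The helper below, `first_broadcast_mismatch`, is a hypothetical name written for this example and is not part of Pyro or TyXe.

```python
def first_broadcast_mismatch(value_shape, dist_shape):
    """Return (dim_from_right, value_size, dist_size) for the first clash, else None."""
    pairs = zip(reversed(value_shape), reversed(dist_shape))
    for offset, (v, d) in enumerate(pairs, start=1):
        # Sizes are compatible if they are equal or if either one is 1.
        if v != 1 and d != 1 and v != d:
            return (-offset, v, d)
    return None


value_shape = (32, 3, 128, 128)  # shape of the observed data in the error message
dist_shape = (32, 3, 128)        # batch_shape + event_shape of the likelihood
mismatch = first_broadcast_mismatch(value_shape, dist_shape)
print(mismatch)
```

Here the last dimensions agree (128 and 128), but the next pair is 128 against 3. This suggests the likelihood is one trailing dimension short of the data, for instance a network output of shape `[32, 3, 128]` compared against images of shape `[32, 3, 128, 128]`, although that is only a guess from the shapes.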