numpee / CKA.pytorch

A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU acceleration.
Apache License 2.0

NaNs #2

Open: repers opened this issue 1 year ago

repers commented 1 year ago

Hi there, thanks for providing the code for the CKA analysis. I have tried applying it to my model, but I keep getting NaNs in the final output matrix. Any idea why this happens? Thanks!

numpee commented 1 year ago

There could be many reasons for NaNs, but I'd need more details to narrow it down.
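One way to gather those details is to find the first layer whose output is already NaN or Inf. The sketch below is plain PyTorch, not part of CKA.pytorch; the helper name find_first_nonfinite and the toy DivideByZero layer are hypothetical. It registers a forward hook on every leaf module, runs one forward pass, and reports the first module, in execution order, whose tensor output is non-finite.

```python
import torch
import torch.nn as nn


def find_first_nonfinite(model: nn.Module, *inputs: torch.Tensor):
    """Return the name of the first leaf module (in execution order) whose
    tensor output contains NaN or Inf, or None if every output is finite."""
    hits = []
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            # Only plain tensor outputs are checked; tuples/dicts are skipped.
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                hits.append(name)
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:  # always detach the hooks again
            handle.remove()
    return hits[0] if hits else None


class DivideByZero(nn.Module):
    """Hypothetical faulty layer used only to demonstrate the helper."""
    def forward(self, x):
        return x / 0.0


broken_model = nn.Sequential(nn.Linear(4, 4), DivideByZero(), nn.ReLU())
print(find_first_nonfinite(broken_model, torch.randn(2, 4)))  # prints: 1
```

For a model whose forward takes several tensors (for example two input frames), pass them all as extra positional arguments. Modules that return tuples or dicts are not inspected by this sketch.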

repers commented 1 year ago

Hi, the network I'm trying to apply this to is https://github.com/HyeongminLEE/AdaCoF-pytorch/blob/master/models/adacofnet.py

I load the test set using the dataset code from https://github.com/tding1/CDFI/blob/main/datasets.py

I suspect it might be caused by the custom CUDA implementation at the end. Is there a way to apply hooks only to the UNet part of the architecture?

This is the main code I run:

```python
import argparse

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from adacofnet1 import make_model
from cka import CKACalculator
from datasets import Vimeo90K_interp

parser = argparse.ArgumentParser(description='AdaCoF-Pytorch')

# parameters
# Model Selection
parser.add_argument('--model', type=str, default='adacofnet')

# Hardware Setting
parser.add_argument('--gpu_id', type=int, default=0)

# Directory Setting
parser.add_argument('--train', type=str, default='./db/vimeo_triplet')
parser.add_argument('--out_dir', type=str, default='./output_adacof_train')
parser.add_argument('--load', type=str, default=None)   # checkpoint for model1
parser.add_argument('--load2', type=str, default=None)  # checkpoint for model2
parser.add_argument('--test_input', type=str, default='./test_input/middlebury_others/input')
parser.add_argument('--gt', type=str, default='./test_input/middlebury_others/gt')

# Learning Options
parser.add_argument('--epochs', type=int, default=50, help='Max Epochs')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
parser.add_argument('--loss', type=str, default='1*Charb+0.01*g_Spatial+0.005*g_Occlusion', help='loss function configuration')
parser.add_argument('--patch_size', type=int, default=128, help='Patch size')

parser.add_argument('--kernel_size', type=int, default=5)
parser.add_argument('--dilation', type=int, default=1)


def main():
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu_id)

    # Only the validation split is used for the CKA comparison
    train_dataset, val_dataset = Vimeo90K_interp(args.train)
    test_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Build both models and load their checkpoints
    model1 = make_model(args)
    checkpoint = torch.load(args.load)
    model1.load_state_dict(checkpoint['state_dict'])
    model2 = make_model(args)
    checkpoint2 = torch.load(args.load2)
    model2.load_state_dict(checkpoint2['state_dict'])

    calculator = CKACalculator(model1=model1, model2=model2, dataloader=test_loader)
    cka_output = calculator.calculate_cka_matrix()
    print(cka_output)

    # Draw the matrix before saving; calling savefig alone writes an empty figure
    plt.rcParams['figure.figsize'] = (7, 7)
    plt.imshow(cka_output.cpu().numpy(), cmap='inferno')
    plt.savefig('new.png')

    # Names of the layers that were hooked
    for name in calculator.module_names_X:
        print(name)


if __name__ == "__main__":
    main()
```

repers commented 1 year ago

One thing to note: the size of the test set does not seem to matter. I tried both a small subset of about 40 images and the full set, and the same issue happened with each.

numpee commented 1 year ago

Hi, you can apply hooks to custom layers by passing those modules into the CKA calculator. See the "Advanced Usage" section of the provided example Jupyter notebook.
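Independently of the calculator's own hook options (whose exact parameter names are not assumed here), restricting hooks to one part of a network can be sketched in plain PyTorch: walk named_modules(), keep only modules whose qualified name starts with a given prefix and whose type is in a chosen set, and register forward hooks on those. ToyNet, its unet/warp attributes, and collect_features are hypothetical stand-ins for the UNet and the custom CUDA stage; the real attribute names can be checked with print(model).

```python
import torch
import torch.nn as nn


def collect_features(model, prefix, layer_types, *inputs):
    """Capture outputs of modules whose qualified name starts with `prefix`
    and whose type is in `layer_types`. Returns {name: (batch, features)}."""
    feats = {}
    handles = []
    for name, module in model.named_modules():
        if name.startswith(prefix) and isinstance(module, layer_types):
            # Bind `name` as a default argument so each hook keeps its own name
            def hook(mod, args, out, name=name):
                feats[name] = out.detach().flatten(1)
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()
    return feats


class ToyNet(nn.Module):
    """Hypothetical model: `unet` is the part to analyse, `warp` stands in
    for a custom stage (e.g. a CUDA kernel) that should not be hooked."""
    def __init__(self):
        super().__init__()
        self.unet = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                                  nn.Conv2d(8, 8, 3, padding=1))
        self.warp = nn.Conv2d(8, 3, 1)

    def forward(self, x):
        return self.warp(self.unet(x))


model = ToyNet().eval()
feats = collect_features(model, "unet.", (nn.Conv2d,), torch.randn(4, 3, 16, 16))
print(sorted(feats))  # ['unet.0', 'unet.2']
```

Each captured tensor is a (batch, features) matrix, the form a CKA comparison between two layers consumes.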

As for the NaNs, I'm not exactly sure what's going on; there may be some numerical underflow or overflow. You could try modifying the epsilon parameter of the CKACalculator. Also, do you normalize the input data before passing it to the model?
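Both failure modes can be illustrated with a small standalone linear CKA. This is the biased linear-kernel form from Kornblith et al. (2019), not the estimator CKA.pytorch uses, and where CKACalculator applies its epsilon is not assumed here. A layer whose activations are identical for every sample in the batch (for instance a ReLU that outputs all zeros) has zero self-similarity, so the ratio becomes 0/0; very large activations overflow float32 when squared, giving inf/inf. Both produce NaN.

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    """Linear CKA between feature matrices x (n, p) and y (n, q).

    eps is added to the denominator; this placement is an assumption of the sketch.
    """
    x = x - x.mean(dim=0, keepdim=True)  # center each feature over the batch
    y = y - y.mean(dim=0, keepdim=True)
    cross = (y.T @ x).pow(2).sum()          # ||Y^T X||_F^2
    norm_x = (x.T @ x).pow(2).sum().sqrt()  # ||X^T X||_F
    norm_y = (y.T @ y).pow(2).sum().sqrt()  # ||Y^T Y||_F
    return cross / (norm_x * norm_y + eps)


torch.manual_seed(0)
x = torch.randn(8, 5)
dead = torch.zeros(8, 5)  # e.g. a layer that outputs zeros for every sample

print(linear_cka(x, dead))            # tensor(nan): 0 / 0
print(linear_cka(x, dead, eps=1e-6))  # tensor(0.)

big = x * 1e15  # squared Gram entries exceed the float32 range
print(linear_cka(big, big))                    # tensor(nan): inf / inf
print(linear_cka(big.double(), big.double()))  # close to 1 in float64
```

A small eps in the denominator turns the constant-layer case into 0 instead of NaN, while upcasting the features to float64 (a swapped-in remedy, not a CKACalculator option) avoids the overflow case.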

WenLinLliu commented 3 months ago

The batch size must be greater than 3.
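A batch-size limit like this is consistent with an unbiased HSIC estimator, whose normalization divides by n(n-3) and by (n-2), where n is the number of samples in a batch: with n = 3 (or 2) the estimate is inf or NaN, and so is any CKA value built from it. The sketch below is a standalone re-implementation of that estimator for illustration, assuming (not confirming) that CKA.pytorch's per-batch HSIC has this form; hsic_unbiased is a hypothetical name. It also shows that a DataLoader with batch_size=4 but without drop_last=True can still yield a final batch smaller than 4 when the dataset length is not a multiple of the batch size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def hsic_unbiased(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Unbiased HSIC estimate for two n x n Gram matrices (sketch)."""
    n = K.shape[0]
    K = K.clone().fill_diagonal_(0)  # K~: Gram matrix with zeroed diagonal
    L = L.clone().fill_diagonal_(0)
    ones = torch.ones(n, 1, dtype=K.dtype)
    trace_term = torch.trace(K @ L)
    middle = (ones.T @ K @ ones) * (ones.T @ L @ ones) / ((n - 1) * (n - 2))
    right = 2.0 * (ones.T @ K @ L @ ones) / (n - 2)
    # n * (n - 3) is zero for n = 3, and (n - 2) is zero for n = 2
    return ((trace_term + middle - right) / (n * (n - 3))).squeeze()


torch.manual_seed(0)
for n in (3, 4, 8):
    feats = torch.randn(n, 5, dtype=torch.float64)
    K = feats @ feats.T  # linear kernel
    print(n, hsic_unbiased(K, K))  # n = 3 gives inf or nan

# A DataLoader without drop_last can yield a short final batch
dataset = TensorDataset(torch.randn(10, 3))
sizes = [batch[0].shape[0] for batch in DataLoader(dataset, batch_size=4)]
safe = [batch[0].shape[0] for batch in DataLoader(dataset, batch_size=4, drop_last=True)]
print(sizes)  # [4, 4, 2]
print(safe)   # [4, 4]
```

Passing drop_last=True to the DataLoader handed to the calculator is one way to keep every batch at the requested size.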