JJGO / hyperlight

Modular and intuitive Hypernetworks in Pytorch
Apache License 2.0
32 stars 3 forks

Batch Normalization support #10

Open Richienb opened 1 week ago

Richienb commented 1 week ago

De Vries et al. (2017) condition a network by hypernetizing only its BatchNorm layers.

Presently, specifying a `BatchNorm2d` module causes `ValueError: Fan in and fan out can not be computed for tensor with fewer than 2 dimensions` at: https://github.com/JJGO/hyperlight/blob/a3e210812ba2b9aede34cdb0550e823241a11e42/hyperlight/hypernet/initialization.py#L45

Then, if I set `init_independent_weights=False`, I get `AttributeError: Uninitialized External Parameter, please set the value first. Did you mean: '_data'?`
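For context, the root cause of the first error seems to be that BatchNorm's learnable weight and bias are 1-D (one scalar per channel), so fan-in/fan-out-based initializers cannot be applied to them. A minimal stand-alone reproduction in plain PyTorch, independent of hyperlight:

```python
import torch.nn as nn

# BatchNorm2d's learnable parameters are 1-D: one scale and one shift
# per channel. Fan-in/fan-out initializers need at least 2 dimensions.
bn = nn.BatchNorm2d(8)
print(bn.weight.shape)  # torch.Size([8])

try:
    # Same failure mode as a fan-based default initialization path.
    nn.init.kaiming_uniform_(bn.weight)
except ValueError as err:
    print(err)  # Fan in and fan out can not be computed for tensor ...
```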

paramrajpura commented 4 days ago

@Richienb Any luck hypernetizing batchnorm? I am facing a similar problem when the batch size is greater than 1; it works otherwise. @JJGO @lumosan could you please help resolve this?

Richienb commented 4 days ago

I did manage to get it working after fixing a mistake I made where I was passing a whole batch of inputs to the hypernetwork and then passing its output directly to the `using_externals` function.

Richienb commented 4 days ago

I think the README should mention the need to set `init_independent_weights=False`.

paramrajpura commented 2 days ago

@Richienb Thank you for your response. Sorry to bother you again, but it wasn't entirely clear to me.

Here is what I am trying: I use a hypernet over a CNN that uses batch norm. Even after setting `init_independent_weights=False`, I get a shape-mismatch error when doing a forward pass with batch size > 1 on `HyperConvNet`.

I am still not clear on what is going wrong here. Can you help me out?

```python
class HyperConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        mainnet = MyCNN()
        modules = hl.find_modules_of_type(mainnet, [nn.Conv2d, nn.Linear, nn.BatchNorm2d])
        self.mainnet = hl.hypernetize(mainnet, modules=modules)
        parameter_shapes = self.mainnet.external_shapes()
        print("Total parameters predicted:", parameter_shapes)
        self.hypernet = hl.HyperNet(
            input_shapes={'h': (1452,)},
            output_shapes=parameter_shapes,
            hidden_sizes=[256, 512],
            init_independent_weights=False,
            fc_kws={'dropout_prob': 0.3},
        )

    def forward(self, main_input, hyper_input):
        parameters = self.hypernet(h=hyper_input)
        # print(parameters['lin.weight'].shape)
        with self.mainnet.using_externals(parameters):
            prediction = self.mainnet(main_input)

        return prediction
```
Richienb commented 2 days ago

@paramrajpura Are you sure that in the input to `forward(self, main_input, hyper_input)`, `hyper_input` is NOT a batch? That would result in `parameters` being a batch, and `using_externals` doesn't like that.
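To illustrate the shape problem in plain PyTorch (a hyperlight-agnostic sketch): a single conditioning vector yields predicted parameters that match the layer's shapes exactly, while a batch of B vectors yields tensors with a leading batch dimension that can no longer be swapped in as the layer's parameters.

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 2)  # weight shape: (2, 4)

# One conditioning vector -> one predicted weight with the exact shape.
single = torch.randn(2, 4)

# A batch of 8 conditioning vectors -> 8 predicted weights, whose shape
# no longer matches lin.weight and cannot be used as that parameter.
batched = torch.randn(8, 2, 4)

print(single.shape == lin.weight.shape)   # True
print(batched.shape == lin.weight.shape)  # False
```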

paramrajpura commented 1 day ago

It's a batch. I wanted to do a batch-forward pass. Is there a way to do it?

Richienb commented 1 day ago

> Is there a way to do it?

I have been iterating over each item in the batch separately, and combining them at the end. I'm not sure if there is a way that lets you do it in parallel, but the way I have just described does work.
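A minimal sketch of that per-sample loop. The `model_step` callable here is a placeholder; in the thread's setup it would run the hypernet on one conditioning vector and then call the main net under `using_externals`:

```python
import torch

def batched_forward(model_step, main_batch, hyper_batch):
    """Apply a single-sample forward function over a batch, one item at
    a time, and stack the outputs.

    model_step(main_x, hyper_x) is assumed to handle inputs with a
    leading batch dimension of 1.
    """
    outputs = []
    for i in range(main_batch.shape[0]):
        outputs.append(model_step(main_batch[i:i + 1], hyper_batch[i:i + 1]))
    return torch.cat(outputs, dim=0)

# Toy usage with a placeholder step function.
main = torch.randn(4, 3)
hyper = torch.ones(4, 2)
out = batched_forward(lambda m, h: m * h.sum(), main, hyper)
print(out.shape)  # torch.Size([4, 3])
```

In newer PyTorch, `torch.func.vmap` together with `torch.func.functional_call` can sometimes parallelize this per-sample pattern, though I have not verified that it composes with hyperlight.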

paramrajpura commented 10 hours ago

Yes, I do that with a loop too, but it gets too time-consuming to train and test on a large dataset. Thank you so much @Richienb for the discussion!