Haian-Jin / Neural_Gaffer

[NeurIPS 2024] Official code for "Neural Gaffer: Relighting Any Object via Diffusion"

Input Layer Weight Initialization in UNet for Fine-Tuning #6

Open BaiYeBuTingXuan opened 2 days ago

BaiYeBuTingXuan commented 2 days ago

In the file neural_gaffer_inference_real_data.py, lines 171-179, I noticed that the input dimension of the UNet is expanded from 8 to 16 channels during the inference phase. The code copies the weights for the first 8 input channels and zero-initializes the weights for the remaining 8 channels. However, I haven't found any corresponding fine-tuning code that trains these zero-initialized input weights. Could you point it out?

Here’s the relevant code snippet for reference:


# Zero-initialize UNet conv_in, expanding it from 8 to 16 input channels
conv_in_16 = torch.nn.Conv2d(
    16,
    unet.conv_in.out_channels,
    kernel_size=unet.conv_in.kernel_size,
    padding=unet.conv_in.padding,
)
conv_in_16.requires_grad_(False)
unet.conv_in.requires_grad_(False)

# Keep the original weights in the first 8 channels; leave the new 8 at zero
torch.nn.init.zeros_(conv_in_16.weight)
conv_in_16.weight[:, :8, :, :].copy_(unet.conv_in.weight)
conv_in_16.bias.copy_(unet.conv_in.bias)

unet.conv_in = conv_in_16
unet.requires_grad_(False)
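
For context, zero-initializing the new channels means the 16-channel layer reproduces the original 8-channel layer's output at initialization, since the extra channels contribute nothing. A minimal standalone check of this property (the 320 output channels and input size are illustrative choices, not values taken from the repo):

import torch

# Illustrative shapes: expand 8 -> 16 input channels, 320 output channels assumed
conv8 = torch.nn.Conv2d(8, 320, kernel_size=3, padding=1)
conv16 = torch.nn.Conv2d(16, 320, kernel_size=3, padding=1)
with torch.no_grad():
    conv16.weight.zero_()
    conv16.weight[:, :8].copy_(conv8.weight)
    conv16.bias.copy_(conv8.bias)

x = torch.randn(1, 8, 64, 64)        # original latent input
extra = torch.randn(1, 8, 64, 64)    # new conditioning channels
x16 = torch.cat([x, extra], dim=1)
# Outputs match because the extra channels are multiplied by zero weights
assert torch.allclose(conv8(x), conv16(x16), atol=1e-5)
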
Haian-Jin commented 1 day ago

Our model was trained from Zero123's checkpoint, so the initial UNet had 8 input channels. We then load our own checkpoint here, which loads our pre-trained model weights.
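
In other words, the zero initialization is only a placeholder that makes the module shapes match before the released weights are loaded. A minimal sketch of that loading step, continuing from the `unet` built in the snippet above (the checkpoint filename and state-dict layout here are assumptions, not the repo's actual loading code):

import torch

# Sketch only: loading the released checkpoint overwrites the zero-initialized
# conv_in with trained weights whose shape is already [out_channels, 16, k, k].
state_dict = torch.load("neural_gaffer_unet.ckpt", map_location="cpu")  # hypothetical path
unet.load_state_dict(state_dict)
# After loading, the last 8 input channels carry trained, nonzero weights
assert not torch.all(unet.conv_in.weight[:, 8:] == 0)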