NUS-HPC-AI-Lab / Neural-Network-Parameter-Diffusion

We introduce a novel approach for parameter generation, named neural network parameter diffusion (p-diff), which employs a standard latent diffusion model to synthesize a new set of parameters
787 stars, 38 forks

Diffusion model setting when trained on whole network #13

Open FelixFeiyu opened 3 months ago

FelixFeiyu commented 3 months ago

Hi,

I tried to reproduce the paper's results with a three-layer CNN on CIFAR10, choosing autoencoder.Latent_AE_cnn_big as ae_model in ae_ddpm.yaml. The classification accuracy of the reconstructed CNN reaches a comparable level (79%). However, once the diffusion network starts training, the accuracy drops to 10%.

Is the diffusion model I used, and the rest of my configuration, correct?

```yaml
name: ae_ddpm
ae_model:
  _target_: core.module.modules.autoencoder.Latent_AE_cnn_big
  in_dim: 39882 #2048

model:
  arch:
    _target_: core.module.wrapper.ema.EMA
    model:
      _target_: core.module.modules.od_unet.AE_CNN_bottleneck
      in_dim: 52

beta_schedule:
  start: 1e-4
  end: 2e-2
  schedule: linear
  n_timestep: 1000

model_mean_type: eps
model_var_type: fixedlarge
loss_type: mse

train:
  split_epoch: 30000
  optimizer:
    _target_: torch.optim.AdamW
    lr: 1e-3
    weight_decay: 2e-6

  ae_optimizer:
    _target_: torch.optim.AdamW
    lr: 1e-3
    weight_decay: 2e-6

  lr_scheduler:
```

Thank you
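The beta_schedule, model_mean_type and loss_type entries in this config describe a standard DDPM setup. The NumPy sketch below (not the repository's code) shows what those values mean: a linear schedule of 1,000 betas from 1e-4 to 2e-2, the closed-form forward noising step, and the eps-prediction MSE target. Treating `in_dim: 52` as the flattened latent length is an assumption, and the all-zeros "prediction" is a placeholder for the real denoiser.

```python
import numpy as np

# Values copied from the beta_schedule block of the ae_ddpm.yaml config above.
START, END, N_TIMESTEP = 1e-4, 2e-2, 1000


def linear_beta_schedule(start: float, end: float, n: int) -> np.ndarray:
    """Linearly spaced per-step noise variances beta_1..beta_T (standard DDPM linear schedule)."""
    return np.linspace(start, end, n, dtype=np.float64)


def q_sample(x0: np.ndarray, t: int, alphas_cumprod: np.ndarray, eps: np.ndarray) -> np.ndarray:
    """Forward process in closed form: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    abar = alphas_cumprod[t]
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * eps


betas = linear_beta_schedule(START, END, N_TIMESTEP)
alphas_cumprod = np.cumprod(1.0 - betas)  # abar_t, shrinks toward 0 as t grows

rng = np.random.default_rng(0)
latent = rng.standard_normal(52)          # assumption: in_dim 52 is the flattened latent length
eps = rng.standard_normal(latent.shape)   # with model_mean_type: eps, this noise is the target
t = int(rng.integers(0, N_TIMESTEP))
x_t = q_sample(latent, t, alphas_cumprod, eps)

# With loss_type: mse, the denoiser is trained so that model(x_t, t) approximates eps.
predicted_eps = np.zeros_like(eps)        # placeholder for the denoiser's output
loss = np.mean((predicted_eps - eps) ** 2)
print(f"abar at t=0: {alphas_cumprod[0]:.6f}, at t=T-1: {alphas_cumprod[-1]:.2e}, loss: {loss:.3f}")
```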

1zeryu commented 3 months ago

> When starting training the diff-network, the accuracy drops to 10%.

Is that 10% the accuracy of the p-diff generated models, or of the autoencoder-reconstructed model? Once the diffusion network starts training, we only test the diffusion-generated models, not the AE-reconstructed model.

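```python
from typing import Any

import numpy as np


# Excerpt: a method of the repository's training system class, not a standalone function.
def validation_step(self, batch, batch_idx, **kwargs: Any):
    batch = self.pre_process(batch)
    # Generate 10 parameter sets with the diffusion model; only generated models are tested here.
    outputs = self.generate(batch, 10)

    # Convert the generated outputs into parameter vectors for evaluation.
    params = self.post_process(outputs)
    params = params.cpu()

    # Evaluate each generated parameter set on the task and collect its accuracy.
    accs = []
    for i in range(params.shape[0]):
        param = params[i].to(batch.device)
        acc, test_loss, output_list = self.task_func(param)
        accs.append(acc)
    best_acc = np.max(accs)
    print("generated models accuracy:", accs)
    print("generated models mean accuracy:", np.mean(accs))
    print("generated models best accuracy:", best_acc)
    self.log('best_g_acc', best_acc)
    self.log('mean_g_acc', np.mean(accs).item())
    return {'best_g_acc': best_acc, 'mean_g_acc': np.mean(accs).item()}
```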
FelixFeiyu commented 3 months ago

Thank you for your response.

The AE-reconstructed model reaches nearly 78% accuracy during the first 30,000 training epochs. But during the last 30,000 epochs, the best accuracy of the generated models is only 13%.

Does this mean the diffusion model was not trained well?
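These two numbers come from two different phases of a single run. A minimal sketch of that split, assuming (from this thread, not from the repository's code) that epochs before `split_epoch` train and evaluate the autoencoder, while later epochs train the diffusion model and evaluate only generated models. Both helper names are hypothetical; `summarize_accuracies` mirrors the best/mean reduction in `validation_step`.

```python
def training_phase(epoch: int, split_epoch: int) -> str:
    """Hypothetical helper: which component is trained and evaluated at a given epoch.

    Assumption: epochs < split_epoch report AE reconstruction accuracy;
    epochs >= split_epoch report accuracy of diffusion-generated parameters.
    """
    if epoch < 0 or split_epoch <= 0:
        raise ValueError("epoch must be >= 0 and split_epoch > 0")
    return "autoencoder" if epoch < split_epoch else "diffusion"


def summarize_accuracies(accs: list[float]) -> dict[str, float]:
    """Best and mean accuracy over the generated models, as logged by validation_step."""
    if not accs:
        raise ValueError("need at least one accuracy")
    return {"best_g_acc": max(accs), "mean_g_acc": sum(accs) / len(accs)}


# With split_epoch: 30000 from the config above:
print(training_phase(29_999, 30_000), training_phase(30_000, 30_000))
print(summarize_accuracies([0.10, 0.13, 0.11]))
```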

1zeryu commented 3 months ago

It looks like the diffusion model wasn't trained well. Have you tried other networks, or a lower learning rate? If you have any further questions, please let me know.

FelixFeiyu commented 3 months ago

What I am doing here is reproducing the experiment "Generalization on entire model parameters" in the paper. Do you plan to publish the related details and settings?

Thank you.

waldun-m commented 3 months ago

I tried to reproduce the same experiment. For an in_dim of about 40k, it actually takes something like 300,000 epochs of LDM training to produce reasonable results.

JoycexxZ commented 3 months ago

I am also trying to reproduce this experiment, but I can't even reconstruct the weights successfully with the VAE. Have you changed any settings for VAE training or task training? At what point does the reconstruction accuracy reach ~80%?

waldun-m commented 3 months ago

VAE training is rather easier: just use the default hyperparameters, maybe tuning the lr and noise factor. I noticed that the default latent noise factor in ae_ddpm.yaml is 0.5, which I consider too large. I trained a VAE on the 40k in_dim case (all parameters of a 4-layer CNN), and the reconstruction accuracy reached a level close to that of the input parameters at about epoch 25,000.
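Whether a latent noise factor of 0.5 is "too large" depends on the scale of the latent codes. The sketch below assumes simple additive Gaussian noise augmentation (the repository's exact formula may differ) and a hypothetical latent whose elements have a standard deviation of about 0.1. It compares the noise-to-signal norm ratio for 0.5 and for the 1e-3 that waldun-m reports using elsewhere in this thread.

```python
import numpy as np


def add_noise(x: np.ndarray, noise_factor: float, rng: np.random.Generator) -> np.ndarray:
    """Additive Gaussian noise augmentation: x + noise_factor * N(0, I).

    Assumption: stands in for the input/latent noise the thread's noise factors control.
    """
    return x + noise_factor * rng.standard_normal(x.shape)


def relative_noise(x: np.ndarray, noise_factor: float, seed: int = 0) -> float:
    """Ratio ||noise|| / ||x||: how large the perturbation is relative to the signal."""
    rng = np.random.default_rng(seed)
    noisy = add_noise(x, noise_factor, rng)
    return float(np.linalg.norm(noisy - x) / np.linalg.norm(x))


rng = np.random.default_rng(42)
latent = 0.1 * rng.standard_normal(4096)  # hypothetical latent, per-element std about 0.1
for factor in (0.5, 1e-3):
    print(f"noise_factor={factor}: relative noise = {relative_noise(latent, factor):.3f}")
```

If the latent scale is this small, a factor of 0.5 buries the signal under noise several times larger than it, whereas 1e-3 is a mild perturbation; with larger latents the same factor is relatively gentler.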

JoycexxZ commented 3 months ago

Hi, I have tried several lr/noise settings but still cannot reproduce the result. For parameter dataset generation, do you use different initializations for the 200 model checkpoints? If so, may I ask for the specific settings (lr/scheduler/noise) you used in the experiment?

FelixFeiyu commented 3 months ago

> Hi, I have tried several lr/noise settings but still cannot reproduce the result. For parameter dataset generation, do you use different initializations for the 200 model checkpoints? If so, may I ask for the specific settings (lr/scheduler/noise) you used in the experiment?

I use autoencoder.Latent_AE_cnn_big, changing its input dimension and setting the optimizer learning rate to 1e-2.

FelixFeiyu commented 3 months ago

> VAE training is rather easier: just use the default hyperparameters, maybe tuning the lr and noise factor. I noticed that the default latent noise factor in ae_ddpm.yaml is 0.5, which I consider too large. I trained a VAE on the 40k in_dim case (all parameters of a 4-layer CNN), and the reconstruction accuracy reached a level close to that of the input parameters at about epoch 25,000.

Hi, I have another question. When I test the trained model directly, I cannot get the same results as the test that runs after training completes. The results seem to come from an initial model instead of a trained one. Did you run into the same problem?

waldun-m commented 3 months ago

> Hi, I have another question. When I test the trained model directly, I cannot get the same results as the test that runs after training completes. The results seem to come from an initial model instead of a trained one. Did you run into the same problem?

Yeah. This is the problem I am dealing with, but I have made no progress for a few days.

waldun-m commented 3 months ago

> Hi, I have tried several lr/noise settings but still cannot reproduce the result. For parameter dataset generation, do you use different initializations for the 200 model checkpoints? If so, may I ask for the specific settings (lr/scheduler/noise) you used in the experiment?

Shouldn't it be: train for the first x epochs until the model almost converges, then train y more epochs (200 by default) to collect the parameter data? I use the default seed, lr=1e-3, the default lr_scheduler, and both noise factors set to 1e-3.

JoycexxZ commented 3 months ago

> Shouldn't it be: train for the first x epochs until the model almost converges, then train y more epochs (200 by default) to collect the parameter data? I use the default seed, lr=1e-3, the default lr_scheduler, and both noise factors set to 1e-3.

In the "Generalization on entire model parameters" section of the paper, there is a sentence: "Different from the aforementioned training data collection strategy, we individually train these architectures from scratch with 200 different random seeds". I think this means that, for this experiment, the 200 models (by default) should each be trained separately from a different random initialization. I also found that using 200 model checkpoints from a single initialization causes both the diffusion model and the VAE to overfit severely, because all the training samples are very similar.
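The two data-collection strategies discussed above differ only in where the 200 samples come from. A minimal sketch, assuming a toy least-squares "model" trained by full-batch gradient descent as a stand-in for real CNN training; all function names are hypothetical. Because the toy problem is convex, it only illustrates the bookkeeping of each strategy, not the diversity gap that JoycexxZ observed with real networks.

```python
import numpy as np


def toy_train_epoch(w: np.ndarray, X: np.ndarray, y: np.ndarray, lr: float) -> np.ndarray:
    """One epoch of full-batch gradient descent on least squares (stand-in for real training)."""
    grad = 2.0 * X.T @ (X @ w - y) / len(y)
    return w - lr * grad


def collect_single_run(X, y, warmup_epochs: int, n_ckpts: int, lr: float = 0.05, seed: int = 0):
    """Strategy A: one initialization, train until near convergence, then save one checkpoint per epoch."""
    w = np.random.default_rng(seed).standard_normal(X.shape[1])
    for _ in range(warmup_epochs):
        w = toy_train_epoch(w, X, y, lr)
    ckpts = []
    for _ in range(n_ckpts):
        w = toy_train_epoch(w, X, y, lr)
        ckpts.append(w.copy())
    return np.stack(ckpts)


def collect_multi_seed(X, y, epochs: int, n_seeds: int, lr: float = 0.05):
    """Strategy B: train n_seeds models from scratch, each with its own seed, keep final weights."""
    finals = []
    for seed in range(n_seeds):
        w = np.random.default_rng(seed).standard_normal(X.shape[1])
        for _ in range(epochs):
            w = toy_train_epoch(w, X, y, lr)
        finals.append(w)
    return np.stack(finals)


rng = np.random.default_rng(123)
X = rng.standard_normal((256, 8))
y = X @ rng.standard_normal(8)
single = collect_single_run(X, y, warmup_epochs=50, n_ckpts=200)
multi = collect_multi_seed(X, y, epochs=50, n_seeds=200)
print(single.shape, multi.shape)  # both are (200, in_dim) parameter datasets
```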

SKDDJ commented 3 months ago

Hello @waldun-m @FelixFeiyu @JoycexxZ, could you kindly share some information about how you configured your training datasets? I am particularly interested in the details of the model parameters, specifically the shape of the flattened 1-D tensor: what is the length of your input dimension? I'd like to understand how the input dimension influences the configuration of the autoencoder and diffusion models in terms of layers and channels. Thanks.
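The `in_dim` in these configs (39882 in the first one) is the length of the flattened parameter vector. A minimal NumPy sketch, assuming a hypothetical small CNN whose layer names and shapes are made up for illustration: flattening every tensor in a fixed order gives the vector the autoencoder consumes, and the inverse slicing maps a generated vector back to per-layer tensors. In PyTorch, `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters` do the same job for a module's parameters.

```python
import numpy as np

# Hypothetical CNN shapes (conv weights are out x in x kH x kW); fc assumes global pooling to 32 features.
LAYER_SHAPES = {
    "conv1.weight": (16, 3, 3, 3), "conv1.bias": (16,),
    "conv2.weight": (32, 16, 3, 3), "conv2.bias": (32,),
    "fc.weight": (10, 32), "fc.bias": (10,),
}


def flatten_params(params: dict[str, np.ndarray]) -> np.ndarray:
    """Concatenate every tensor, in a fixed key order, into one 1-D vector (length = in_dim)."""
    return np.concatenate([params[k].ravel() for k in LAYER_SHAPES])


def unflatten_params(vec: np.ndarray) -> dict[str, np.ndarray]:
    """Inverse of flatten_params: slice the vector back into the original shapes."""
    out, offset = {}, 0
    for name, shape in LAYER_SHAPES.items():
        size = int(np.prod(shape))
        out[name] = vec[offset:offset + size].reshape(shape)
        offset += size
    if offset != vec.size:
        raise ValueError(f"expected {offset} values, got {vec.size}")
    return out


rng = np.random.default_rng(0)
params = {k: rng.standard_normal(s) for k, s in LAYER_SHAPES.items()}
vec = flatten_params(params)
print("in_dim =", vec.size)
```

Keeping the key order fixed matters: the autoencoder sees position i of the vector as the same weight in every training sample.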

SKDDJ commented 3 months ago

For example, I aim to train an NND to generate a complete model with nearly ten million (or potentially even more) parameters. How should I configure the autoencoder and diffusion model to suit such a parameter count during training? Are there any existing scaling laws for parameter generation, like those for LLMs?

FelixFeiyu commented 2 months ago

> Yeah. This is the problem I am dealing with, but I have made no progress for a few days.

The problem seems to be that the trained network used by the diffusion model is not saved successfully: when I checked the output .ckpt file, there were no related parameters. After adding code to save the network after training and to load it before testing, running the test directly gives the same result. However, I haven't found out why the original code doesn't save the network automatically.
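The workaround described above (save the trained network explicitly after training, load it before testing) can be sketched with the standard PyTorch `state_dict` idiom. This is a minimal sketch assuming a plain `nn.Module`; the function names, the file name and the small stand-in network are hypothetical and are not taken from the repository.

```python
import os
import tempfile

import torch
from torch import nn


def save_network(model: nn.Module, path: str) -> None:
    """Persist only the learned weights (the state_dict), the usual PyTorch idiom."""
    torch.save(model.state_dict(), path)


def load_network(model: nn.Module, path: str) -> nn.Module:
    """Restore weights into a freshly constructed model with the same architecture."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()  # inference mode (disables dropout, freezes batch-norm statistics) before testing
    return model


def make_denoiser() -> nn.Module:
    """Hypothetical stand-in for the diffusion network (52 matches the config's in_dim)."""
    return nn.Sequential(nn.Linear(52, 64), nn.SiLU(), nn.Linear(64, 52))


torch.manual_seed(0)
trained = make_denoiser()
path = os.path.join(tempfile.mkdtemp(), "ddpm_denoiser.pt")
save_network(trained, path)
restored = load_network(make_denoiser(), path)  # fresh random init, then overwritten by the saved weights

x = torch.randn(4, 52)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
print("outputs match:", same)
```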

DrinkWaterOrange commented 2 months ago

> The problem seems to be that the trained network used by the diffusion model is not saved successfully: when I checked the output .ckpt file, there were no related parameters. After adding code to save the network after training and to load it before testing, running the test directly gives the same result. However, I haven't found out why the original code doesn't save the network automatically.

Hi~ Could you share how to save and load the network properly? Thanks a lot.

rdf87 commented 1 month ago

> I tried to reproduce the paper's results with a three-layer CNN on CIFAR10, choosing autoencoder.Latent_AE_cnn_big as ae_model in ae_ddpm.yaml. The classification accuracy of the reconstructed CNN reaches a comparable level (79%). However, once the diffusion network starts training, the accuracy drops to 10%.

Hi~ May I ask whether the problem you encountered has been resolved? I plan to use MNIST to reconstruct the parameters of LeNet5. In the first 30,000 epochs we achieved good accuracy, but in the last 30,000 epochs the accuracy plummeted to 10%. We then tried generating only the weights of the fc3 layer of LeNet5, which raised the accuracy to 90%. In other words, I cannot reproduce the paper's experiment on generating the entire network. In addition, according to Table 3 of the paper, the number of parameters in LeNet5 should be within the generation capacity of the diffusion model. Did the authors not disclose the diffusion model they actually used, or is there an issue with the data in the paper?
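Generating only fc3 shrinks the generation target by a large factor. A quick parameter count, assuming the common LeNet-5 variant (two 5x5 convolutions followed by 400 → 120 → 84 → 10 fully connected layers; rdf87's exact model may differ), shows fc3 holds 850 of 61,706 parameters, about 1.4% of the whole network.

```python
def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    """Weights (c_out * c_in * k * k) plus one bias per output channel."""
    return c_out * c_in * k * k + c_out


def linear_params(n_in: int, n_out: int) -> int:
    """Weights (n_out * n_in) plus one bias per output unit."""
    return n_out * n_in + n_out


# Assumed LeNet-5 layout for 1-channel MNIST input with a 16 x 5 x 5 feature map before fc1.
LENET5 = {
    "conv1": conv2d_params(1, 6, 5),
    "conv2": conv2d_params(6, 16, 5),
    "fc1": linear_params(16 * 5 * 5, 120),
    "fc2": linear_params(120, 84),
    "fc3": linear_params(84, 10),
}

total = sum(LENET5.values())
for name, n in LENET5.items():
    print(f"{name}: {n} params ({100 * n / total:.2f}% of {total})")
```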