yongchaoz / FRePo

Official Code for Dataset Distillation using Neural Feature Regression (NeurIPS 2022)

Would you like to give scripts or hyper-parameter table for other settings? #3

Closed zhaoguangxiang closed 1 year ago

zhaoguangxiang commented 2 years ago

Using your script directly for CIFAR100 with 1 IPC, I got 27.7 acc (+0.5 compared to the paper). However, I only got 39.9 acc (-1.4) on CIFAR100 with 10 IPC when I reused the 1 IPC hyperparameters. For CIFAR100 with 50 IPC, I got "nan"; below is the log:

INFO:absl:Saved checkpoint at train_log/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100/saved_ckpt/checkpoint_1
INFO:absl:[500] monitor/learning_rate=0.0002999993448611349, monitor/steps_per_second=1.837596, proto/x_proto_norm=nan, proto/y_proto_margin_max=nan, proto/y_proto_margin_mean=nan, rad_norm_x=nan, train/grad_norm_y=nan, train/kernel_loss=nan, train/label_loss=nan, train/top5accuracy=0.049998436123132706, train/total_loss=nan

yongchaoz commented 2 years ago

Hi Guangxiang,

Good catch. There is an argument called use_flip that I did not set correctly for the cifar100 setting. This argument determines whether to flip the images during the meta-gradient computation. Empirically, flipping works well when the distilled set is small but hurts performance when it is large. I just pushed the change to the master branch, and it should be fine now.
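
To make the role of such a switch concrete, here is a minimal sketch of a random left-right flip gated by a boolean. It is written in NumPy for self-containment and is not the repository's implementation; the function name maybe_flip and the NHWC layout are assumptions made for illustration.

```python
import numpy as np

def maybe_flip(images, rng, use_flip=True):
    """Randomly mirror NHWC images left-right; a no-op when use_flip is False."""
    if not use_flip:
        return images
    # One coin toss per image, broadcast over height, width, and channels.
    flip = rng.random(images.shape[0]) < 0.5
    return np.where(flip[:, None, None, None], images[:, :, ::-1, :], images)

rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 32, 32, 3)).astype(np.float32)
unchanged = maybe_flip(batch, rng, use_flip=False)  # returned as-is
augmented = maybe_flip(batch, rng, use_flip=True)   # each image kept or mirrored
```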

I notice that your output prints rad_norm_x=nan, while it is supposed to be train/grad_norm_x=nan. Is this a typo introduced in your script?

Best, Yongchao

fiona-lxd commented 1 year ago

Actually, I have pulled the latest repo and run cifar100 with 50 IPC, and the loss still went to nan as follows:
INFO:absl:[1500] monitor/learning_rate=0.0002999940188601613, monitor/steps_per_second=4.218012, proto/x_proto_norm=nan, proto/y_proto_margin_max=nan, proto/y_proto_margin_mean=nan, proto/y_proto_margin_min=nan, proto/y_proto_max_max=nan, proto/y_proto_max_mean=nan, proto/y_proto_max_min=nan, proto/y_proto_norm=nan, train/accuracy=0.009988280944526196, train/grad_norm_x=nan, train/grad_norm_y=nan, train/kernel_loss=nan, train/label_loss=nan, train/top5accuracy=0.050001952797174454, train/total_loss=nan

I have checked the use_flip and found that it has been set to "False" as you suggested. Could you please provide some help with this situation?
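
When comparing logs like the one above, it can help to list exactly which metrics are nan at a given step. The following is a small sketch that assumes the metrics appear as comma-separated name=value pairs, as in the lines quoted in this thread; the helper names parse_metrics and nan_metrics are made up for illustration.

```python
import math
import re

# Matches "name=value" where value is nan, inf, or a plain/scientific number.
_PAIR = re.compile(r"([\w/]+)=(-?(?:nan|inf|\d+\.?\d*(?:e[+-]?\d+)?))")

def parse_metrics(log_line):
    """Parse name=value pairs from a metrics log line into a dict of floats."""
    return {name: float(value) for name, value in _PAIR.findall(log_line)}

def nan_metrics(log_line):
    """Return the sorted names of all metrics whose value is nan."""
    return sorted(k for k, v in parse_metrics(log_line).items() if math.isnan(v))

line = ("INFO:absl:[1500] monitor/learning_rate=0.0002999940188601613, "
        "proto/x_proto_norm=nan, train/accuracy=0.009988280944526196, "
        "train/total_loss=nan")
print(nan_metrics(line))
```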

yongchaoz commented 1 year ago

Hi Fiona,

Could you show me the command you are running? Does the problem go away when you use a different random seed? If you are using the "conv" arch, here are two implementations which differ in where the normalization layer is placed. "Implementation 1" seems to have better performance on Imagenette, while "Implementation 2" seems to be better for CIFAR. I do not have a good explanation for this observation, but it may potentially solve your problem.

Implementation 1 (Current One).

for i in range(self.depth):
    if i != 0 and self.normalization != 'identity':
        x = norm_layer()(x)

    if i == 0 and channel == 1:
        pad = (self.kernel_size[0] // 2 + 2, self.kernel_size[0] // 2 + 2)
    else:
        pad = (self.kernel_size[0] // 2, self.kernel_size[0] // 2)

    x = nn.Conv(features=self.width * (2 ** i), kernel_size=self.kernel_size,
                padding=(pad, pad), use_bias=True, dtype=self.dtype)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, (2, 2), strides=(2, 2))

Implementation 2 (First Commit).

for i in range(self.depth):
    if i == 0 and channel == 1:
        pad = (self.kernel_size[0] // 2 + 2, self.kernel_size[0] // 2 + 2)
    else:
        pad = (self.kernel_size[0] // 2, self.kernel_size[0] // 2)

    x = nn.Conv(features=self.width * (2 ** i), kernel_size=self.kernel_size,
                padding=(pad, pad), use_bias=True, dtype=self.dtype)(x)
    if not self.normalization == 'identity':
        x = norm_layer()(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, (2, 2), strides=(2, 2))
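
To see the difference between the two loops at a glance, the sketch below lists the order of operations each one produces. It only mirrors the control flow of the snippets above using strings; the function name layer_order is hypothetical, and padding details are omitted.

```python
def layer_order(depth, implementation, normalization="batch"):
    """List the op sequence produced by each variant of the conv loop."""
    use_norm = normalization != "identity"
    ops = []
    for i in range(depth):
        # Implementation 1: normalize the block input, except for the first block.
        if implementation == 1 and i != 0 and use_norm:
            ops.append("norm")
        ops.append("conv")
        # Implementation 2: normalize right after every convolution.
        if implementation == 2 and use_norm:
            ops.append("norm")
        ops += ["relu", "avg_pool"]
    return ops

print(layer_order(3, implementation=1))
print(layer_order(3, implementation=2))
```

With depth=3, the first variant applies two normalization layers (none before the first convolution), while the second applies three, one after each convolution.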

Best, Yongchao

fiona-lxd commented 1 year ago

Thank you for your quick reply. The command is listed as follows. I only changed the dataset_name and num_prototypes_per_class to cifar100 and 50 respectively.

export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false"

path="--dataset_name=cifar100 --train_log=train_log --train_img=train_img --zca_path=data/zca --data_path=./data/tensorflow_datasets --save_image=True"
exp="--learn_label=True --random_seed=0"
arch="--arch=conv --width=128 --depth=3 --normalization=batch"
hyper="--max_online_updates=100 --num_nn_state=10 --num_train_steps=500000"
python -m script.distill $path $exp $arch $hyper --num_prototypes_per_class=50

I have tested with 3 seeds and have also tried the two implementations you suggested. However, it always returns the "nan" loss. Besides, in my recent experiments the result on TinyImageNet with ipc=1 is only 10.35+-0.53, while it is 15.4+-0.3 in your paper. The command is the same as the one above except for dataset_name (=tiny_imagenet) and num_prototypes_per_class (=1). Could you help me figure out what I did wrong?

yongchaoz commented 1 year ago

For the TinyImageNet, you may want to change the depth to 4 to account for the increase in resolution.
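
A quick arithmetic sketch of why the depth matters: each block in the conv loop above ends with a 2x2 average pool with stride 2, so the spatial size halves per block. The helper below assumes the convolutions preserve spatial size (as the padding in the snippets suggests) and that TinyImageNet images are 64x64; the function name is made up for illustration.

```python
def final_feature_map_size(resolution, depth):
    """Spatial size after `depth` blocks that each halve height and width."""
    size = resolution
    for _ in range(depth):
        size //= 2  # 2x2 average pool with stride 2
    return size

cifar = final_feature_map_size(32, depth=3)         # 32x32 input
tiny_depth3 = final_feature_map_size(64, depth=3)   # 64x64 input, same depth
tiny_depth4 = final_feature_map_size(64, depth=4)   # 64x64 input, one more block
print(cifar, tiny_depth3, tiny_depth4)
```

Under these assumptions, depth 4 on a 64x64 input ends at the same spatial size as depth 3 on a 32x32 input.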

For the NaN problem, I did not encounter the same issue on my end. Could you show me a more complete version of your training log, starting from step 0?

fiona-lxd commented 1 year ago

Thank you for your suggestion. I will try the depth=4 for TinyImageNet.

For the Nan problem, below is the log:
INFO:absl:Load dataset info from ./data/tensorflow_datasets/cifar100/3.0.2
INFO:absl:Reusing dataset cifar100 (./data/tensorflow_datasets/cifar100/3.0.2)
INFO:absl:Load from data/zca/cifar100_normalize_zca.npz!
INFO:absl:Constructing tf.data.Dataset cifar100 for split ['train', 'test'], from ./data/tensorflow_datasets/cifar100/3.0.2
INFO:absl:Dataset size: 50000
WARNING:tensorflow:From /data/xiangxi/env/FRePo/conda/envs/frepo/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /data/xiangxi/env/FRePo/conda/envs/frepo/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Process the data: 100%|████████████| 10/10 [00:05<00:00, 2.00it/s]
INFO:absl:Dataset size: 10000
Process the data: 100%|██████████████| 2/2 [00:00<00:00, 2.16it/s]
INFO:absl:Resolution: 32
INFO:absl:Proto Scale: {'x_proto': Array(55.425625, dtype=float32, weak_type=True)}
INFO:absl:Working directory: train_log_th_06/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100
INFO:absl:image directory: train_img_th_06/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100
INFO:absl:Found no checkpoint directory at train_log_th_06/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100/proto
INFO:absl:Initial compilation, this might take some minutes...
/data/xiangxi/codes/2022/Governance_for_Generated_Images/DDPM_for_density/dataset_distillation/FRePo.bak/lib/datadistillation/frepo.py:111: FutureWarning: The sym_pos argument to solve() is deprecated and will be removed in a future JAX release. Use assume_a='pos' instead.
  pred = jnp.dot(k_tp, sp.linalg.solve(k_pp_reg, y_proto, sym_pos=True))
INFO:absl:Initial compilation completed. Elapsed: 19.851025819778442s
INFO:absl:[1] monitor/learning_rate=0.0003000000142492354, monitor/steps_per_second=4619.888091, proto/x_proto_norm=nan, proto/y_proto_margin_max=nan, proto/y_proto_margin_mean=nan, proto/y_proto_margin_min=nan, proto/y_proto_max_max=nan, proto/y_proto_max_mean=nan, proto/y_proto_max_min=nan, proto/y_proto_norm=nan, train/accuracy=0.0087890625, train/grad_norm_x=nan, train/grad_norm_y=nan, train/kernel_loss=nan, train/label_loss=-0.009999999776482582, train/top5accuracy=0.0517578125, train/total_loss=nan
Train on distilled data: 5000it [01:30, 55.56it/s]
Evaluate using nn predictor: 10it [00:01, 8.10it/s]
INFO:absl:{'/nn_accuracy_mean': 0.01, '/nn_loss_mean': nan, '/nn_top5accuracy_mean': 0.050000004}
Train on distilled data: 5000it [01:24, 59.34it/s]
Evaluate using nn predictor: 10it [00:00, 83.14it/s]
INFO:absl:{'/nn_accuracy_mean': 0.01, '/nn_loss_mean': nan, '/nn_top5accuracy_mean': 0.050000004}
Train on distilled data: 5000it [01:24, 58.97it/s]
Evaluate using nn predictor: 10it [00:00, 92.36it/s]
INFO:absl:{'/nn_accuracy_mean': 0.01, '/nn_loss_mean': nan, '/nn_top5accuracy_mean': 0.050000004}
Train on distilled data: 5000it [01:24, 59.07it/s]
Evaluate using nn predictor: 10it [00:00, 72.35it/s]
INFO:absl:{'/nn_accuracy_mean': 0.01, '/nn_loss_mean': nan, '/nn_top5accuracy_mean': 0.050000004}
Train on distilled data: 5000it [01:24, 59.36it/s]
Evaluate using nn predictor: 10it [00:00, 62.76it/s]
INFO:absl:{'/nn_accuracy_mean': 0.01, '/nn_loss_mean': nan, '/nn_top5accuracy_mean': 0.050000004}
INFO:absl:[1] eval/step_acc_mean=1.000000
INFO:absl:[1] eval/step_std=0.000000
INFO:absl:Saving checkpoint at step: 1
INFO:absl:Saved checkpoint at train_log_th_06/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100/saved_ckpt/checkpoint_1
INFO:absl:[500] monitor/learning_rate=0.0002999993448611349, monitor/steps_per_second=0.892568, proto/x_proto_norm=nan, proto/y_proto_margin_max=nan, proto/y_proto_margin_mean=nan, proto/y_proto_margin_min=nan, proto/y_proto_max_max=nan, proto/y_proto_max_mean=nan, proto/y_proto_max_min=nan, proto/y_proto_norm=nan, train/accuracy=0.009990684688091278, train/grad_norm_x=nan, train/grad_norm_y=nan, train/kernel_loss=nan, train/label_loss=nan, train/top5accuracy=0.049998436123132706, train/total_loss=nan
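
The FutureWarning in the log above points at the sym_pos argument of the linear solve. As a hedged illustration of the replacement it recommends, here is a sketch of a kernel-regression prediction that uses assume_a="pos" instead. It uses scipy.linalg.solve on NumPy arrays rather than the repository's JAX code, and the scaled-identity regularizer, the function name krr_predict, and the random test data are assumptions.

```python
import numpy as np
from scipy.linalg import solve

def krr_predict(k_tp, k_pp, y_proto, reg=1e-6):
    """Kernel ridge regression prediction from target/prototype kernels."""
    # Assumed regularizer: a small multiple of the identity on the prototype kernel.
    k_pp_reg = k_pp + reg * np.eye(k_pp.shape[0])
    # assume_a="pos" tells the solver the matrix is symmetric positive definite,
    # which is what the deprecated sym_pos=True used to express.
    return k_tp @ solve(k_pp_reg, y_proto, assume_a="pos")

rng = np.random.default_rng(0)
feat_p = rng.normal(size=(8, 16))   # prototype features (synthetic)
feat_t = rng.normal(size=(5, 16))   # target features (synthetic)
y_proto = rng.normal(size=(8, 3))   # prototype labels (synthetic)
pred = krr_predict(feat_t @ feat_p.T, feat_p @ feat_p.T, y_proto)
```

This only addresses the deprecation warning; nothing in the thread suggests that the warning itself is the cause of the nan values.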

yongchaoz commented 1 year ago

I just ran the same setting and attached my log below. It is quite surprising to me that you get nan in the first step, even for x_proto_norm and y_proto_norm. In this case, the NaN may not be caused by any gradient computation; something weird is going on in your initialization. Could you double check that you can run cifar100 with 1 or 10 IPC successfully?

INFO:absl:Constructing tf.data.Dataset cifar100 for split ['train', 'test'], from /tensorflow_datasets/cifar100/3.0.2
INFO:absl:Dataset size: 50000
Process the data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:10<00:00, 1.03s/it]
INFO:absl:Dataset size: 10000
Process the data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.19it/s]
INFO:absl:Resolution: 32
INFO:absl:Proto Scale: {'x_proto': DeviceArray(55.425625, dtype=float32, weak_type=True)}
INFO:absl:Working directory: train_log/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100
INFO:absl:image directory: train_img/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100
INFO:absl:Found no checkpoint directory at train_log/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100/proto
INFO:absl:Initial compilation, this might take some minutes...
/Project/FRePo/lib/datadistillation/frepo.py:111: FutureWarning: The sym_pos argument to solve() is deprecated and will be removed in a future JAX release. Use assume_a='pos' instead.
  pred = jnp.dot(k_tp, sp.linalg.solve(k_pp_reg, y_proto, sym_pos=True))
2022-12-28 22:16:20.296328: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:729] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms. Conv: (f32[5000,16,16,128]{2,1,3,0}, u8[0]{0}) custom-call(f32[5000,16,16,256]{2,1,3,0}, f32[3,3,128,256]{1,0,2,3}), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
INFO:absl:Initial compilation completed. Elapsed: 24.816324949264526s
INFO:absl:[1] monitor/learning_rate=0.0003000000142492354, monitor/steps_per_second=559.485215, proto/x_proto_norm=25.78557777404785, proto/y_proto_margin_max=0.31622788310050964, proto/y_proto_margin_mean=0.3162107765674591, proto/y_proto_margin_min=0.3162081837654114, proto/y_proto_max_max=0.3130752742290497, proto/y_proto_max_mean=0.3130582571029663, proto/y_proto_max_min=0.31305569410324097, proto/y_proto_norm=0.31463536620140076, train/accuracy=0.18359375, train/grad_norm_x=0.019630448892712593, train/grad_norm_y=0.07783296704292297, train/kernel_loss=0.5940134525299072, train/label_loss=-0.009999999776482582, train/top5accuracy=0.294921875, train/total_loss=0.5840134620666504
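
One way to narrow down whether values are already non-finite before any gradient step is to check the initial arrays directly. The sketch below is a generic NumPy check, not code from the repository; the dictionary keys x_proto and y_proto are borrowed from the metric names in the logs, and the arrays are made-up placeholders.

```python
import numpy as np

def non_finite_report(named_arrays):
    """Return the names of arrays that contain any NaN or Inf entries."""
    return [name for name, arr in named_arrays.items()
            if not np.all(np.isfinite(arr))]

# Placeholder state: one clean array and one with a NaN injected on purpose.
state = {
    "x_proto": np.ones((5, 32, 32, 3), dtype=np.float32),
    "y_proto": np.array([0.3, np.nan], dtype=np.float32),
}
bad = non_finite_report(state)
print(bad)
```

Running such a check right after initialization, and again after the first update, would indicate whether the problem originates in the initial prototypes or in the first computation on them.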

fiona-lxd commented 1 year ago

The experiments for Cifar100 with ipc=1 and ipc=10 can be successfully conducted. The results are 27.85+-0.16 and 41.13+-0.39 respectively while they are 28.7+-0.1 and 42.5+-0.2 in your paper.

I started again by cloning your code and running the above command, and still got 'NaN'.

yongchaoz commented 1 year ago

Honestly, I have no idea why that is the case. Maybe you have different JAX and CUDA versions from mine? But I guess that may not be the issue.

One thing I notice: I have XLA_FLAGS=--xla_gpu_cuda_data_dir=/scratch/ssd001/pkgs/cuda-11.3 instead of export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false".
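
For experimenting with different XLA_FLAGS values from Python rather than the shell, a small helper like the following could be used. It is a sketch: the function name set_xla_flag is hypothetical, it assumes flags are whitespace-separated and contain no spaces, and the variable generally has to be set before JAX initializes its backend to have any effect.

```python
import os

def set_xla_flag(flag, env=None):
    """Add or replace one --name=value entry in XLA_FLAGS and return the result."""
    env = os.environ if env is None else env
    name = flag.split("=", 1)[0]
    # Drop any existing entry for the same flag name, keep everything else.
    kept = [f for f in env.get("XLA_FLAGS", "").split()
            if f.split("=", 1)[0] != name]
    env["XLA_FLAGS"] = " ".join(kept + [flag])
    return env["XLA_FLAGS"]

# Demonstrated on a plain dict so the real environment is left untouched.
demo_env = {"XLA_FLAGS": "--xla_gpu_cuda_data_dir=/scratch/ssd001/pkgs/cuda-11.3"}
merged = set_xla_flag("--xla_gpu_strict_conv_algorithm_picker=false", demo_env)
print(merged)
```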

Since I cannot reproduce the error on my end, I may not be able to provide more help. But I am happy to learn from you if you find the reason.

Good luck, Yongchao

fiona-lxd commented 1 year ago

My CUDA version is 11.3 while the jax version is 0.4.1. I have updated the XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false" setting but still get NaN. I noticed that the person who opened this issue seems to have the same problem as me (proto/x_proto_norm=nan); therefore, I think this may be a common issue when reproducing your results. Maybe you can try to clone this repo and run the command to reproduce the error? If the NaN problem doesn't occur, then I think it is caused by environmental differences.
And here is my environment:
cuda 11.3
tensorflow 2.11.0
tensorflow_datasets 4.7.0
jax 0.4.1
Could you provide yours? What is the cudnn version in your environment?
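
To make environment comparisons like this easier to reproduce, the installed Python package versions can be collected programmatically. This is a minimal sketch using only the standard library; the function name collect_versions is made up, and it covers Python packages only, so CUDA, cuDNN, and driver versions still need to be read from other tools such as nvidia-smi.

```python
from importlib import metadata

def collect_versions(packages):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

report = collect_versions(["jax", "jaxlib", "tensorflow", "tensorflow-datasets"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```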

yongchaoz commented 1 year ago

Yes, I cloned the repo again to test, but I cannot reproduce the NaN problem. Here are my corresponding package versions.

cuda-11.3
tensorflow 2.10.0
tensorflow-datasets 4.7.0
jax 0.3.21
jaxlib 0.3.20+cuda11.cudnn82
cudnn-11.4-v8.2.4.15

nvidia-smi: NVIDIA-SMI 470.141.03, Driver Version: 470.141.03, CUDA Version: 11.4

fiona-lxd commented 1 year ago

I found that this is caused by different GPU devices. Under the same docker environment, it goes 'NaN' on a 3090 while it does not on a V100. BTW, I found a strange thing in "script/eval.py": the "normalization" is always set to "identity". When I tried to reproduce the results in Table 15 ("Cross-architecture transfer performance on CIFAR10 with 10 Img/Cls") by setting "normalization" and "has_bn" to "batch" and True in "script/eval.py" respectively, the results for models with BN were not as good as you reported. For example, the result for images trained on Conv-BN and transferred to Vgg-BN is 54.4, while it is 59.4 in your paper. The same holds for Conv-BN, whose performance is only 57.3. Can you offer some help with this?

yongchaoz commented 1 year ago

It's good to know that it is caused by different GPU devices. I have tested various GPU devices in the past, including the P100, T4, V100, RTX6000, and A100, and have not encountered any issues. I am a bit surprised that the 3090 behaves differently.

For the evaluation, I use "identity" on purpose (see Appendix A.1, Model section). All of my "conv" results were evaluated using "identity", unless there was a "BN" flag present. I made this design choice because I found that adding batch normalization improves the learning of distilled images but negatively impacts evaluation performance. This can be seen in Tables 12, 13, and 14. While I did not thoroughly investigate the cause, it is possible that using the batch statistics of the distilled data is the issue. Alternatively, using the statistics of the original data, as suggested in the DSA/DM paper, may be a better approach. Additionally, it may be necessary to adjust the learning rate when using different normalizations. A good starting point for batch normalization is 0.01.
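
To illustrate the distinction between the two sources of normalization statistics mentioned above, here is a small NumPy sketch on synthetic features. It is not the evaluation code from script/eval.py; the helper names, the NHWC layout, and the random data are assumptions, and the sketch says nothing about which option gives better accuracy.

```python
import numpy as np

def channel_stats(x):
    """Per-channel mean and variance over batch and spatial axes (NHWC)."""
    return x.mean(axis=(0, 1, 2)), x.var(axis=(0, 1, 2))

def normalize(x, mean, var, eps=1e-5):
    """Standardize features with the given per-channel statistics."""
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
original = rng.normal(loc=2.0, scale=3.0, size=(512, 4, 4, 8))   # stand-in for real-data features
distilled = rng.normal(loc=0.5, scale=1.0, size=(10, 4, 4, 8))   # stand-in for distilled-data features

# Option A: normalize the distilled batch with its own batch statistics.
own = normalize(distilled, *channel_stats(distilled))
# Option B: normalize the distilled batch with statistics of the original data.
ref = normalize(distilled, *channel_stats(original))
```

With its own statistics, the distilled batch is centered per channel by construction; with the original data's statistics, it keeps whatever offset and scale it has relative to the real data.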

Best, Yongchao