GeorgeCazenavette / mtt-distillation

Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"
https://georgecazenavette.github.io/mtt-distillation/

Negative LR #16

Closed: cliangyu closed this issue 2 years ago

cliangyu commented 2 years ago

Hi! Thank you for your great work.

When I was distilling my own dataset, I got a very large loss (at iter = 0490) and a negative learning rate.

Could you help me figure out what is happening here? Which hyperparameters should be adjusted in such a case? Can we add anything to the code to prevent a negative LR?

Thank you!

Evaluate 5 random ConvNetD4, mean = 0.2429 std = 0.0080
-------------------------
[2022-08-14 00:29:04] iter = 0400, loss = 1.2390
[2022-08-14 00:29:12] iter = 0410, loss = 1.3564
[2022-08-14 00:29:19] iter = 0420, loss = 1.5845
[2022-08-14 00:29:27] iter = 0430, loss = 0.9945
[2022-08-14 00:29:35] iter = 0440, loss = 1.4876
[2022-08-14 00:29:43] iter = 0450, loss = 1.0734
[2022-08-14 00:29:51] iter = 0460, loss = 1.9312
[2022-08-14 00:29:58] iter = 0470, loss = 1.0497
[2022-08-14 00:30:06] iter = 0480, loss = 16.3134
[2022-08-14 00:30:14] iter = 0490, loss = 23.7197
-------------------------
Evaluation
model_train = ConvNetD4, model_eval = ConvNetD4, iteration = 500
DSA augmentation strategy: color_crop_cutout_flip_scale_rotate
DSA augmentation parameters:
 {'aug_mode': 'S', 'prob_flip': 0.5, 'ratio_scale': 1.2, 'ratio_rotate': 15.0, 'ratio_crop_pad': 0.125, 'ratio_cutout': 0.5, 'ratio_noise': 0.05, 'brightness': 1.0, 'saturation': 2.0, 'contrast': 0.5, 'batchmode': False, 'latestseed': -1}
Traceback (most recent call last):
  File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/distill.py", line 496, in <module>
    main(args)
  File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/distill.py", line 227, in main
    _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args, texture=args.texture)
  File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/utils.py", line 400, in evaluate_synset
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
  File "/media/ntu/volume1/home/s121md302_06/anaconda3/envs/distillation/lib/python3.9/site-packages/torch/optim/sgd.py", line 91, in __init__
    raise ValueError("Invalid learning rate: {}".format(lr))
ValueError: Invalid learning rate: -0.00048201243043877184
GeorgeCazenavette commented 2 years ago

Hello!

You could wrap the learning rate in an absolute value, but that would not solve the underlying issue.
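For reference, a minimal sketch of that workaround (the SGD call is the one from the traceback above; the helper name and the abs() guard are just illustrative, not part of this repo):

import torch
import torch.nn as nn

def make_eval_optimizer(net: nn.Module, lr) -> torch.optim.SGD:
    # Hypothetical guard around the SGD construction in evaluate_synset (utils.py):
    # take the absolute value so a distilled lr that drifted negative no longer
    # crashes the evaluation. This masks the symptom, not the divergence itself.
    lr = abs(float(lr))
    return torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)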

I would first try lowering the learning rates for both the synthetic images and the learnable learning rate.

That said, a loss that large implies the synthetic trajectory ended extremely far from the target; losses should typically lie in [0, 1].

You could also try lowering max_start_epoch. The matching loss is normalized by how far the expert moves between the starting and target points, so if the expert is already mostly converged at the starting point, that distance becomes very small and the loss blows up.

Please let me know if there are still issues after you try this!

cile98 commented 2 years ago

@c-liangyu In the Dataset Distillation paper (https://arxiv.org/abs/1811.10959), they applied a softplus to the learnable learning rate to prevent it from becoming negative; maybe that helps you.
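For anyone who wants to try that, here is a minimal PyTorch sketch of the softplus reparameterization (all names and initial values are hypothetical, not taken from this repo): learn an unconstrained scalar and pass it through softplus so the effective learning rate used for the student updates is always positive.

import math
import torch
import torch.nn.functional as F

# Illustrative only: learn a raw scalar and map it through softplus,
# so the effective inner-loop learning rate can never become negative.
lr0 = 0.01                                                             # desired initial effective lr
raw_lr = torch.tensor(math.log(math.expm1(lr0)), requires_grad=True)  # softplus(raw_lr) == lr0
lr_optimizer = torch.optim.SGD([raw_lr], lr=1e-5, momentum=0.5)        # outer optimizer for the lr

def effective_lr() -> torch.Tensor:
    # softplus(x) = log(1 + exp(x)) > 0 for every x
    return F.softplus(raw_lr)

# Inside the distillation loop, use effective_lr() wherever the learnable student
# lr would be applied, e.g. student_params = student_params - effective_lr() * grad,
# and step lr_optimizer alongside the image optimizer.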

cliangyu commented 2 years ago

@GeorgeCazenavette A lower learning rate works for me, thank you. @cile98 Thanks for the tip; I'll give it a shot too :)