google-research / multinerf

A Code Release for Mip-NeRF 360, Ref-NeRF, and RawNeRF
Apache License 2.0

Question about parameter modification for low batch size #82

Closed dogyoonlee closed 1 year ago

dogyoonlee commented 1 year ago

I'm using two RTX 3090s to run multinerf (specifically RawNeRF).

Due to OOM issues, I set both batch_size and render_chunk_size to 2048. With those values it runs without problems, but I'm wondering what extra modifications are needed to reproduce the results in the paper.

For example, if I set batch_size and render_chunk_size to 1/N of 16384 (the original setting in this repo), should the number of training iterations be N * 500000?

In addition, is it necessary to modify the learning rate, decay steps, or other parameters?

Thank you for your awesome work!

0LEL0 commented 1 year ago

I have the same question and am waiting for a reply.

0LEL0 commented 1 year ago

I also tried running it on TPUs. It works stably with a batch size of 16384, but it is very slow: MXU utilization only reaches about 20%. What should I modify to make the TPUs work more efficiently?

jonbarron commented 1 year ago

Yes, if you reduce the batch size you should increase the number of iterations and decrease the learning rate. This is usually referred to as the "linear scaling rule"; there's a lot of information available about this online.
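As a concrete illustration of that rule, here is a small Python sketch (not code from this repo). The baseline values batch_size=16384 and max_steps=500000 come from this thread; lr_init=1e-3 and lr_final=1e-5 are assumptions inferred from the numbers discussed below, so check them against llff_raw.gin before relying on them.

def scale_hyperparams(new_batch_size,
                      base_batch_size=16384,
                      base_max_steps=500_000,
                      base_lr_init=1e-3,    # assumed baseline, verify in llff_raw.gin
                      base_lr_final=1e-5):  # assumed baseline, verify in llff_raw.gin
  """Linear scaling rule: scale iterations up and learning rates down."""
  scale = new_batch_size / base_batch_size  # e.g. 2048 / 16384 = 1/8
  return {
      'Config.batch_size': new_batch_size,
      'Config.max_steps': int(base_max_steps / scale),  # 8x more steps
      'Config.lr_init': base_lr_init * scale,           # 8x smaller LR
      'Config.lr_final': base_lr_final * scale,
  }

print(scale_hyperparams(2048))
# {'Config.batch_size': 2048, 'Config.max_steps': 4000000,
#  'Config.lr_init': 0.000125, 'Config.lr_final': 1.25e-06}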

dogyoonlee commented 1 year ago

@jonbarron Great! Thanks a lot :).

In addition, I have one more question.

I trained RawNeRF using train_raw.sh on the morningkitchen scene from the RawNeRF dataset.

However, while the PSNR increased to around 52 during training, evaluation fails with a PSNR of around 13.

The only parameters I modified were batch_size and render_chunk_size in the llff_raw.gin file, and the scene name (morningkitchen) in train_raw.sh.

Is there something wrong with how I'm training RawNeRF?

jonbarron commented 1 year ago

You changed the batch size but not the learning rate and the number of iterations? If so, you should change the learning rate and number of iterations as per the linear scaling rule.

dogyoonlee commented 1 year ago

I forgot to mention the detailed parameters I used in training. I tried the hyperparameters below, since I am using 2 GPUs (RTX 3090). As I understand the linear scaling rule, what matters is how many samples are processed per device, so I set the learning rate to 1/8 of the original value, since as far as I know RawNeRF uses 16 TPUs in its original setting. In addition, I set the learning-rate delay steps and the maximum number of iterations to 8 times the original setting so that optimization runs to completion. But it doesn't work: evaluation performance is still poor (around 15 PSNR) despite high training performance (around 52 PSNR). I stopped training at 740000 steps since it still showed poor results (screenshot attached).

Modified hyperparameters in llff_raw.gin:

Config.batch_size = 2048
Config.render_chunk_size = 2048
Config.lr_init = 0.000125
Config.lr_final = 0.00000125
Config.max_steps = 4000000
Config.checkpoint_every = 25000
Config.lr_delay_steps = 20000
Config.lr_delay_mult = 0.01
Config.grad_max_norm = 0.1
Config.grad_max_val = 0.1
Config.adam_eps = 1e-8

Is there anything wrong here?

Now I'm training with Config.lr_init = 0.0000625 and Config.lr_final = 0.000000625. Thank you for your help!

jonbarron commented 1 year ago

I don't think the number of GPUs is relevant here.

dogyoonlee commented 1 year ago

@jonbarron I'm sorry to bother you again. I tried many values for Config.batch_size, Config.render_chunk_size, Config.lr_init, Config.lr_final, Config.max_steps, and Config.lr_delay_steps on 2 GPUs (RTX 3090), but none of them worked. In particular, once training reaches a certain step (which varies with the hyperparameter setting), the training PSNR drops drastically. When I set lr_init and lr_final to 1.5625e-5 and 1.5625e-7, which are very small compared to the original setting, the training PSNR increases to around 17, but it drops again after about 6400 iterations (with lr_delay_steps=160000). I suspect there is also a problem with the warmup setting (lr_delay_steps), independent of the learning rate.
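For reference on the warmup question: in mip-NeRF-style codebases the warmup factor multiplies a log-linear decay from lr_init to lr_final. The sketch below is an approximation of that schedule (not the exact multinerf implementation), using the values reported above, and may help when reasoning about how lr_delay_steps and lr_delay_mult interact with the learning rate.

import numpy as np

def lr_schedule(step, lr_init, lr_final, max_steps,
                lr_delay_steps=0, lr_delay_mult=1.0):
  """Approximate mip-NeRF-style schedule: sine warmup times log-linear decay.

  A sketch for illustration, not the exact multinerf code.
  """
  if lr_delay_steps > 0:
    # Warmup: ramps from lr_delay_mult up to 1 over the first lr_delay_steps.
    delay = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay = 1.0
  # Log-linear interpolation from lr_init to lr_final over max_steps.
  t = np.clip(step / max_steps, 0, 1)
  return delay * np.exp((1 - t) * np.log(lr_init) + t * np.log(lr_final))

# Example: the settings reported above (batch 2048, 1/8-scaled learning rate).
for step in [0, 20_000, 160_000, 4_000_000]:
  print(step, lr_schedule(step, 1.25e-4, 1.25e-6, 4_000_000,
                          lr_delay_steps=20_000, lr_delay_mult=0.01))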

I will try other hyperparameters according to the linear scaling rule, as you suggested.

Again, thank you for your help and awesome work!!