google / nerfies

This is the code for Deformable Neural Radiance Fields, a.k.a. Nerfies.
https://nerfies.github.io
Apache License 2.0

Implementation about the 2D Toy #41

Open longbowzhang opened 2 years ago

longbowzhang commented 2 years ago

Hi @keunhong

I am quite interested in the 2D toy example shown in Fig. 23 and Fig. 24, but I have failed to reproduce the results.

Could you kindly share the Colab notebook with me? If not, could you answer the following questions?

Context: to simplify the example, I created a 2D toy example with only rigid transformations (i.e., a random rotation in the range [-180°, 180°] plus a translation). I also simplified the warp network to depend only on the GLO embedding vector, so that it produces one uniform rigid transformation for each observed image.
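For concreteness, the rigid warp described above could look roughly like the sketch below. It is only an illustration of the setup, not code from the paper or the notebook: the function name `rigid_transform_2d` and the sample angle and translation are hypothetical, standing in for whatever a warp network conditioned on a GLO embedding would output per image.

```python
import math


def rigid_transform_2d(points, angle_deg, translation):
    """Rotate 2D points about the origin by angle_deg, then translate them."""
    theta = math.radians(angle_deg)
    c, s = math.cos(theta), math.sin(theta)
    tx, ty = translation
    return [(c * x - s * y + tx, s * x + c * y + ty) for x, y in points]


# Hypothetical per-image parameters (one rotation and one translation per image).
points = [(1.0, 0.0), (0.0, 1.0)]
warped = rigid_transform_2d(points, 90.0, (0.5, -0.5))

# The two ends of the sampled range, -180° and 180°, describe the same rotation.
plus_180 = rigid_transform_2d(points, 180.0, (0.0, 0.0))
minus_180 = rigid_transform_2d(points, -180.0, (0.0, 0.0))
```

Note that `plus_180` and `minus_180` agree up to floating-point error even though the two angle values are 360° apart, so a network that regresses the angle as a single scalar sees a discontinuity at the ends of the range.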

Q1: I find it really hard for the network to output reasonable rotations. Do you have any hints on this?

Q2: You mention in the paper the difficulty caused by "orientation flips". Did you face a similar problem in the 2D toy setting? (I noticed that your 2D toy example also uses random rotations in the range [-180°, 180°].)

Thank you very much in advance. Best, Longbow

keunhong commented 2 years ago

Here you go!

https://github.com/google/hypernerf/blob/main/notebooks/figures/nerfies_2d_experiments.ipynb

longbowzhang commented 2 years ago

Hi @keunhong Thank you so much for sharing! I can now reproduce similar results on my own dataset by (1) using a much lower learning rate (e.g., 1e-4) and (2) setting template_min_freq to -2.

I still have several questions:
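To make the effect of a negative minimum frequency concrete, here is a minimal sketch of a sinusoidal encoding. It assumes that template_min_freq is the base-2 exponent of the lowest frequency band and that the bands are 2**k with integer k; both are assumptions about the notebook, and the function name `sinusoidal_encoding` and its parameters are hypothetical.

```python
import math


def sinusoidal_encoding(x, min_freq_log2, max_freq_log2):
    """Encode scalar x with sin/cos features at frequencies 2**k, min <= k < max."""
    features = []
    for k in range(min_freq_log2, max_freq_log2):
        freq = 2.0 ** k
        features.extend([math.sin(freq * x), math.cos(freq * x)])
    return features


# Frequency bands when the lowest exponent is 0 versus -2 (upper bound is illustrative).
bands_from_zero = [2.0 ** k for k in range(0, 4)]
bands_from_minus_two = [2.0 ** k for k in range(-2, 4)]

encoded = sinusoidal_encoding(0.5, -2, 4)
```

Under these assumptions, lowering the minimum exponent from 0 to -2 adds two bands with frequencies 0.25 and 0.5, which vary more slowly over the input than any band in the default set.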

  1. Where is the definition of Model? I could not find it.

    VModel = nn.vmap(Model, in_axes=(0, 0, None, None), variable_axes={'params': None}, split_rngs={'params': False})
    model = VModel(num_glo_embeddings=len(images), deform_type=deform_type, ...

  2. I am extremely curious about why you set template_min_freq to -2. Could you explain the motivation for setting min_freq below 0?

  3. In the 2D case, you do not need to handle the "orientation flips" problem as in 3D. Why?

  4. I notice that you use a fixed learning rate for Adam. Why, then, define a learning rate schedule?

    lr_schedule = schedules.from_config({
    'type': 'delayed',
    'delay_steps': 50,
    'delay_mult': 0.01,
    'base_schedule': {
    'type': 'exponential',
    'initial_value': 8e-3,
    'final_value': 8e-5,
    'num_steps': max_iters,
    },
    })
    optimizer_def = optim.Adam(lr_schedule(0))
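The following plain-Python sketch illustrates why `lr_schedule(0)` yields a single fixed value. It is a re-implementation under assumptions, not the library code: it assumes the 'delayed' schedule multiplies the base schedule by a factor that ramps from delay_mult to 1 over delay_steps, and that the 'exponential' schedule interpolates geometrically from initial_value to final_value. The value of `max_iters` is hypothetical.

```python
import math


def exponential_schedule(step, initial_value, final_value, num_steps):
    """Geometric interpolation from initial_value to final_value (assumed form)."""
    if step >= num_steps:
        return final_value
    exponent = step / (num_steps - 1)
    return initial_value * (final_value / initial_value) ** exponent


def delayed_schedule(step, delay_steps, delay_mult, base):
    """Scale base(step) by a factor ramping from delay_mult to 1 (assumed form)."""
    ramp = math.sin(0.5 * math.pi * min(max(step / delay_steps, 0.0), 1.0))
    return (delay_mult + (1.0 - delay_mult) * ramp) * base(step)


max_iters = 1000  # hypothetical; the real value is not given in this thread


def lr_schedule(step):
    base = lambda s: exponential_schedule(s, 8e-3, 8e-5, max_iters)
    return delayed_schedule(step, 50, 0.01, base)


# Evaluating the schedule once at step 0 gives one number, not a function of step.
fixed_lr = lr_schedule(0)
```

Under these assumptions, the step-0 value is delay_mult times initial_value, i.e. 0.01 * 8e-3 = 8e-5, so passing `lr_schedule(0)` to the optimizer fixes the learning rate at that value for the whole run.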

Best :)