NVlabs / nvdiffrec

Official code for the CVPR 2022 (oral) paper "Extracting Triangular 3D Models, Materials, and Lighting From Images".

Not able to train second pass mesh #118

Open iraj465 opened 1 year ago

iraj465 commented 1 year ago
Base mesh has 19035 triangles and 10472 vertices.
Writing mesh:  out/manual_SHOE/dmtet_mesh/mesh.obj
    writing 10472 vertices
    writing 20702 texcoords
    writing 10472 normals
    writing 19035 faces
Writing material:  out/manual_SHOE/dmtet_mesh/mesh.mtl
Done exporting mesh
------------------------------
reflvec.shape[0] 1
reflvec.shape[1] * reflvec.shape[2] 921600
reflvec.shape[3] 3
*reflvec.shape  1 720 1280 3
mtx  torch.Size([1, 4, 4])
------------------------------
reflvec.shape[0] 2
reflvec.shape[1] * reflvec.shape[2] 921600
reflvec.shape[3] 3
*reflvec.shape  2 720 1280 3
mtx  torch.Size([1, 4, 4])
------------------------------
Traceback (most recent call last):                                                                                                                                                                                                                   
  File "train.py", line 632, in <module>
    geometry, mat = optimize_mesh(glctx, geometry, base_mesh.material, lgt, dataset_train, dataset_validate, FLAGS,
  File "train.py", line 420, in optimize_mesh
    img_loss, reg_loss = trainer(target, it)
  File "/Models/conda_envs/saptarshi.majumder/envs/nvdiffrec-prod/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Models/conda_envs/saptarshi.majumder/envs/nvdiffrec-prod/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/Models/conda_envs/saptarshi.majumder/envs/nvdiffrec-prod/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 971, in _run_ddp_forward
    return module_to_run(*inputs, **kwargs)
  File "/Models/conda_envs/saptarshi.majumder/envs/nvdiffrec-prod/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "train.py", line 300, in forward
    return self.geometry.tick(glctx, target, self.light, self.material, self.image_loss_fn, it)
  File "/home/saptarshi.majumder/bifrost/geometry/dlmesh.py", line 56, in tick
    buffers = self.render(glctx, target, lgt, opt_material)
  File "/home/saptarshi.majumder/bifrost/geometry/dlmesh.py", line 48, in render
    return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
  File "/home/saptarshi.majumder/bifrost/render/render.py", line 233, in render_mesh
    layers += [(render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf), rast)]
  File "/home/saptarshi.majumder/bifrost/render/render.py", line 166, in render_layer
    buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv,
  File "/home/saptarshi.majumder/bifrost/render/render.py", line 79, in shade
    shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
  File "/home/saptarshi.majumder/bifrost/render/light.py", line 106, in shade
    reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
RuntimeError: shape '[2, 720, 1280, 3]' is invalid for input of size 2764800
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 16526 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 16527 closing signal SIGTERM

Training fails right after the dmtet_pass output is saved. I have printed the relevant shapes to make this clearer. I am using a batch_size of 2 with 720x1280 images.
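The sizes in the error message can be checked by hand. A minimal sketch in plain Python, using only the shapes from the printout and the RuntimeError in the log:

```python
from math import prod

# Values copied from the log in this issue.
target_shape = (2, 720, 1280, 3)  # what .view(*reflvec.shape) asks for
actual_numel = 2764800            # "invalid for input of size 2764800"

# The flattened middle dimension printed in the log: 720 * 1280 = 921600.
pixels = prod(target_shape[1:3])
print(pixels)  # 921600

# Elements a [batch, 720, 1280, 3] tensor needs, for batch 1 and batch 2.
needed = {batch: prod((batch, *target_shape[1:])) for batch in (1, 2)}
print(needed)  # {1: 2764800, 2: 5529600}

# Batch size the tensor being reshaped actually holds.
implied_batch = actual_numel // prod(target_shape[1:])
print(implied_batch)  # 1
```

In other words, the tensor handed to `.view` holds exactly one 720x1280 image worth of vectors, while the target shape asks for two.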

constantm commented 1 year ago

Which GPU are you using? In my experience the second pass failing is usually due to either running out of VRAM, or the initial mesh pass output being of very low quality. What was the loss you were seeing on the first pass?

iraj465 commented 1 year ago

The img-loss stays roughly constant after 100-200 iterations.

jmunkberg commented 1 year ago

From the log above, it looks like the reshape operation is failing:

reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
RuntimeError: shape '[2, 720, 1280, 3]' is invalid for input of size 2764800

which is strange, as I would assume reflvec has shape [2, 720, 1280, 3], not [1, 720, 1280, 3], when this triggered.
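To make the failing line concrete: it flattens reflvec from [B, H, W, 3] to [B, H*W, 3], rotates every vector by mtx, then views the result back as [B, H, W, 3]. The dependency-free sketch below mimics only that shape bookkeeping. `flatten_hw`, `unflatten_hw` and `xfm_directions` are hypothetical helpers written for illustration, not nvdiffrec's `ru.xfm_vectors`, and reusing a single (batch-1) matrix for every batch entry is an assumption of the sketch.

```python
# Shape bookkeeping of the failing line in light.py, in plain Python.
# Hypothetical helpers; NOT nvdiffrec's ru.xfm_vectors.


def flatten_hw(x):
    """[B][H][W] of (x, y, z) -> [B][H*W], like reflvec.view(B, H * W, 3)."""
    return [[v for row in img for v in row] for img in x]


def unflatten_hw(x, h, w):
    """[B][H*W] -> [B][H][W], like .view(*reflvec.shape)."""
    for img in x:
        if len(img) != h * w:
            raise RuntimeError(f"cannot view {len(img)} vectors as {h}x{w}")
    return [[img[r * w:(r + 1) * w] for r in range(h)] for img in x]


def xfm_directions(vecs, mtx):
    """Rotate direction vectors by 4x4 matrices with w = 0 (out = v @ M^T).

    Assumption of this sketch: a batch-1 matrix list is reused for every
    batch entry, so the output batch size always equals len(vecs).
    """
    b, m = len(vecs), len(mtx)
    if m not in (1, b):
        raise ValueError(f"matrix batch {m} does not broadcast to batch {b}")
    out = []
    for i in range(b):
        mat = mtx[0] if m == 1 else mtx[i]
        # Only the upper-left 3x3 block matters: w = 0 drops the translation.
        out.append([tuple(sum(mat[r][c] * v[c] for c in range(3)) for r in range(3))
                    for v in vecs[i]])
    return out


# Tiny stand-in for the log: batch 2 (as in reflvec), one matrix (as in mtx).
B, H, W = 2, 2, 3
ROT_Z_90 = [[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
reflvec = [[[(1.0, 0.0, 0.0)] * W for _ in range(H)] for _ in range(B)]

out = unflatten_hw(xfm_directions(flatten_hw(reflvec), [ROT_Z_90]), H, W)
print(len(out), len(out[0]), len(out[0][0]))  # 2 2 3
print(out[1][1][2])  # (0.0, 1.0, 0.0)
```

With that broadcasting assumption the output keeps batch size 2. If the real transform instead sized its output by the matrix batch (the log prints mtx as torch.Size([1, 4, 4])), the result would hold 1 x 720 x 1280 x 3 = 2764800 values, which cannot be viewed as [2, 720, 1280, 3]. That matches the error message, but it is an inference from the printed shapes rather than a confirmed diagnosis.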

That said, that part of the code should not execute unless you have "camera_space_light": true set in the config. This should only be set if your setup has a moving object and a static camera and light. If this is not your capture setup, results will be really bad.

In most other examples, the object and the light are static and only the camera moves. For that setup, you should set "camera_space_light": false, or remove that line from the config.
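The rule of thumb from the two paragraphs above can be written as a small decision function. `camera_space_light_for` is a hypothetical helper for illustration only; the thread just says the flag is set in the config, and this sketch only shows which value each capture setup calls for and how the entry looks once serialized to JSON.

```python
import json


def camera_space_light_for(object_moves: bool, camera_moves: bool, light_moves: bool) -> bool:
    """Hypothetical helper: pick camera_space_light from the capture setup."""
    if object_moves and not camera_moves and not light_moves:
        # Moving object, static camera and light: camera-space lighting.
        return True
    if camera_moves and not object_moves and not light_moves:
        # The common case in the examples: static object and light, moving camera.
        return False
    raise ValueError("capture setup not covered by the advice in this thread")


# Moving object with a static camera and light: emit the config entry.
entry = {
    "camera_space_light": camera_space_light_for(
        object_moves=True, camera_moves=False, light_moves=False
    )
}
print(json.dumps(entry))  # {"camera_space_light": true}
```

Setups outside these two cases raise an error on purpose, since the advice above does not cover them.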

iraj465 commented 1 year ago

I have a static camera and light, but the object rotates on a turntable. We would need camera_space_light: true in this case, right? With camera_space_light: true set, how can I resolve this?