ziyc / drivestudio

A 3DGS framework for omni urban scene reconstruction and simulation.
https://ziyc.github.io/omnire/
MIT License

failed at start of train.py #15

Closed: lcc815 closed this issue 1 month ago.

lcc815 commented 1 month ago

Thanks for your work. My issue is as follows:

I ran:

python tools/train.py --config_file configs/omnire.yaml --output_root output --project omnire --run_name first_exp dataset=waymo/3cams data.scene_idx=114 data.start_timestep=0 data.end_timestep=-1

but got the following error:

Traceback (most recent call last):
  File "tools/train.py", line 376, in <module>
    final_step = main(args)
  File "tools/train.py", line 200, in main
    render_results = render_images(
  File "***/drivestudio/models/video_utils.py", line 62, in render_images
    render_results = render(
  File "***/drivestudio/models/video_utils.py", line 140, in render
    results = trainer(image_infos, cam_infos)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "***/drivestudio/models/trainers/scene_graph.py", line 235, in forward
    gs = self.collect_gaussians(
  File "***/drivestudio/models/trainers/base.py", line 356, in collect_gaussians
    gs = self.models[class_name].get_gaussians(cam)
  File "***/drivestudio/models/nodes/smpl.py", line 356, in get_gaussians
    world_means, world_quats = self.transform_means_and_quats(self._means, self._quats)
  File "***/drivestudio/models/nodes/smpl.py", line 293, in transform_means_and_quats
    W, A = self.template(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "***/drivestudio/models/human_body.py", line 178, in forward
    W = self.voxel_deformer(xyz_canonical)  # B,N,24+K
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "***/drivestudio/models/modules.py", line 622, in forward
    w = w.squeeze(3, 4).permute(0, 2, 1)
TypeError: squeeze() received an invalid combination of arguments - got (int, int), but expected one of:
 * ()
 * (int dim)
 * (name dim)
ziyc commented 1 month ago

Hi @lcc815, thanks for reporting this. I haven't encountered this bug before. Could you run some other Waymo scenes to see if this issue exists there as well? This will help me determine whether it's a data-specific problem or a code issue. I appreciate your help!

lcc815 commented 1 month ago

Running the following code reproduces this error:

import torch
a = torch.randn(3, 4, 5, 1, 1)
a.squeeze(3, 4)  # TypeError on older PyTorch, where squeeze() accepts only a single dim

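So I guess this is a PyTorch version mismatch: as the error message shows, this PyTorch version's squeeze() accepts only a single dim, while passing multiple dims was added in PyTorch 2.0. I changed the line mentioned above to w = w.squeeze(3).squeeze(3).permute(0, 2, 1), which solved the issue. The second squeeze(3) is correct because removing dim 3 shifts the original dim 4 down to index 3.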
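If the code needs to run on both older and newer PyTorch versions, one option is to squeeze each dim separately, highest index first, so that the remaining indices are not shifted. This is a minimal sketch; `squeeze_dims` is a hypothetical helper, not part of drivestudio or PyTorch:

```python
import torch


def squeeze_dims(t: torch.Tensor, *dims: int) -> torch.Tensor:
    """Squeeze several dims one at a time (hypothetical helper).

    Works on PyTorch versions whose squeeze() accepts only a single dim.
    """
    # Normalize negative dims against the original rank, then squeeze the
    # highest index first so the lower indices stay valid.
    ndim = t.dim()
    for d in sorted({d % ndim for d in dims}, reverse=True):
        t = t.squeeze(d)
    return t


# Same shape as the reproducer in this thread.
w = torch.randn(3, 4, 5, 1, 1)
w = squeeze_dims(w, 3, 4).permute(0, 2, 1)
print(w.shape)  # torch.Size([3, 5, 4])
```

This gives the same result as the chained squeeze(3).squeeze(3) fix, without relying on the index-shift reasoning at the call site.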