PeizhuoLi / manifold-aware-transformers

A novel manifold-aware transformer architecture for predicting garment dynamics on unseen geometries [EUROGRAPHICS 2024]

Training Error #4

Open · handhp1 opened this issue 3 days ago

handhp1 commented 3 days ago

Hi, thanks for the nice work!

While training on the vto-dataset, I got the error below:

Traceback (most recent call last):
  File "train_frame_based.py", line 231, in <module>
    main()
  File "train_frame_based.py", line 227, in main
    train(0, world_size, args, t_model)
  File "/root/anaconda3/envs/manifold-aware-transformers/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/root/anaconda3/envs/manifold-aware-transformers/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/root/anaconda3/envs/manifold-aware-transformers/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/anaconda3/envs/manifold-aware-transformers/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/mnt/d/wkspace/manifold-transformer/dataset/handle_dataset.py", line 918, in __getitem__
    res = self.datasets[i][idx]
  File "/mnt/d/wkspace/manifold-transformer/dataset/handle_dataset.py", line 748, in __getitem__
    in_dict2 = self.get_repr_from_pos(relative_stretch=relative_stretch, centroids=centroids, orients=orients,
  File "/mnt/d/wkspace/manifold-transformer/dataset/handle_dataset.py", line 814, in get_repr_from_pos
    body_pos_combined = torch.stack([body_pos0, body_pos], dim=0)
TypeError: expected Tensor as element 0 in argument 0, but got NoneType

How can I fix this? (There is no body_pos.) body_pos is needed to compute base_deformation:

# excerpt from get_repr_from_pos in dataset/handle_dataset.py
  if base_deformation is None:
      body_pos_combined = torch.stack([body_pos0, body_pos], dim=0)
      base_deformation = self.repr_get_base_deform(vert_pos_gt[-1:], body_pos_combined, cfg=cfg)[0]
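The TypeError means that element 0 of the list passed to torch.stack, i.e. body_pos0, is None when base_deformation has not been supplied. Below is a minimal sketch of a guard that could be called just before the stack so that the failure names the missing input(s). check_not_none is a hypothetical helper, not part of the repository. It only produces a clearer error message and does not provide the missing body positions. Those presumably have to come from the dataset options, as the follow-up comment suggests.

```python
from typing import Any, Dict


def check_not_none(**inputs: Any) -> Dict[str, Any]:
    """Raise a ValueError naming every keyword argument whose value is None.

    Hypothetical helper (not in manifold-aware-transformers): intended to be
    called before torch.stack so a missing body_pos0 / body_pos is reported by
    name instead of as "expected Tensor as element 0 ... but got NoneType".
    """
    # Keyword arguments keep their call order, so names are reported in order.
    missing = [name for name, value in inputs.items() if value is None]
    if missing:
        raise ValueError("required inputs are None: " + ", ".join(missing))
    return inputs
```

In the excerpt above, it would go on the line before the stack, e.g. `check_not_none(body_pos0=body_pos0, body_pos=body_pos)`.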
handhp1 commented 3 days ago

I used the option provided as an example and it worked. Thank you.

HackerHuangZY commented 23 minutes ago

> I used the option provided as an example and it worked. Thank you.

Hello, could you share how you solved the problem? Thank you.