JohnRomanelis / SPVD

SPVD: Efficient and Scalable Point Cloud Generation with Sparse Point-Voxel Diffusion Models
https://johnromanelis.github.io/_spvd/

IndexError: max(): Expected reduction dim 0 to have non-zero size. #8

Open jianchaoci opened 1 week ago

jianchaoci commented 1 week ago

Hello,

Thank you for the great work! I am running your code for point cloud completion and I get this error during inference. I did some debugging and realized that, in the code below, the shape of x_t becomes (4, 4) instead of (2048, 4) after self.update_rule(). I suspect this is the cause, but I do not know the exact reason. It also means self.sample_step() can only iterate once. I would appreciate your help!

    t_batch = torch.full((bs,), t, device=device, dtype=torch.long)
    # pdb.set_trace()

    # activate the model to predict the noise
    noise_pred = model((x_t, t_batch)) if emb is None else model((x_t, t_batch, emb))

    # calculate the new point coordinates
    x_t = self.update_rule(x_t, noise_pred, t, i, shape, device)
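
A quick way to surface this during debugging is a shape guard right after the update (a minimal sketch using the variable names from the snippet above; the expected point count of 2048 is an assumption based on the numbers reported here):

    # Hypothetical debug guard: x_t should stay (2048, 4) throughout sampling.
    expected_points = 2048
    assert x_t.shape == (expected_points, 4), \
        f"x_t collapsed to {tuple(x_t.shape)} at step t={t}"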
JohnRomanelis commented 1 week ago

That doesn’t sound quite right—x_t should typically keep a shape of (2048, 4) throughout the completion process. Could you share a bit more about the experiment you're running? For example, are you using the completion scheduler? And what type and shape of inputs are you working with? I'd be happy to help troubleshoot!

jianchaoci commented 1 week ago

Thank you very much! Yes, I am using the DDPMSparseCompletionSchedulerGPU() scheduler. I checked that the input point cloud has 2048 points, and I use the default values to generate the dataset and load the data. I debugged further and realized that pts.shape = (1, 2048, 4) but coords.shape = (a small value, 3) after batch_sparse_quantize_torch(). I think this is the problem. It only happens after the first iteration of the loop for i, t in enumerate(self.strategy.steps), once t is smaller than 999.

    def torch2sparse(self, pts: torch.Tensor, shape):
        pts = pts.reshape(shape)

        coords = pts[..., :3]  # in case points have additional features
        coords = coords - coords.min(dim=1, keepdim=True).values
        pdb.set_trace()
        coords, indices = batch_sparse_quantize_torch(coords, voxel_size=self.pres, return_index=True, return_batch_index=False)
        feats = pts.view(-1, shape[-1])[indices]
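
One way to confirm that the quantization step is what drops the points is to count them before and after it. A small debug sketch meant to sit inside torch2sparse, right after the batch_sparse_quantize_torch call (variable names follow the snippet above):

    # Debug sketch: compare point counts around voxel quantization.
    n_before = pts.shape[0] * pts.shape[1]   # total input points, e.g. 1 * 2048
    n_after = coords.shape[0]                # points surviving quantization
    print(f"{n_before} points in, {n_after} out "
          f"({n_before - n_after} merged at voxel size {self.pres})")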
JohnRomanelis commented 1 week ago

Are you using a model trained on completion, or are you experimenting with a model with randomly initialized weights? A model that is not trained may produce points with duplicate (or very close) coordinates, which will then be merged in the next step due to voxelization.
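
To illustrate the merging effect described above: once coordinates are snapped to a voxel grid, points that fall into the same voxel collapse into one. A self-contained sketch, where the floor-divide quantization is a simplified stand-in for batch_sparse_quantize_torch rather than the repo's exact implementation:

    import torch

    voxel_size = 0.1
    pts = torch.tensor([[0.01, 0.02, 0.03],
                        [0.02, 0.02, 0.03],   # nearly identical to the first point
                        [0.90, 0.50, 0.10]])

    # Snap to integer voxel coordinates, then keep unique voxels only.
    coords = torch.floor(pts / voxel_size).int()
    unique_coords = torch.unique(coords, dim=0)

    print(coords.shape)         # torch.Size([3, 3])
    print(unique_coords.shape)  # torch.Size([2, 3]) -- the two close points merged

With an untrained model, many predicted points can end up this close together, so the number of surviving voxels shrinks drastically between steps.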

jianchaoci commented 1 week ago

Maybe that is the case; I am using a model trained for only 10 epochs. Maybe I should train for more epochs? Could you share a trained model?

JohnRomanelis commented 1 week ago

Yes, absolutely! I’m quite confident this isn’t the exact checkpoint used in the paper, but it should work well for debugging. I’ve also included some sample data that you can use for completion. Hope this helps! https://drive.google.com/drive/folders/1pLkapwySaJrv1eJmOCt62eRrgTY-DCo2?usp=sharing

jianchaoci commented 1 week ago

[screenshot of the loading error attached] I am having a problem loading your shared weights. I think it is because I am using the wrong model? These are the parameters I use to create the network: [screenshot of the parameters attached, 2024-11-13]

JohnRomanelis commented 1 week ago

The checkpoint weights are for the smallest SPVD variant. I see now that I took the code from an experimental version of the repo, so the network initialization code is not available on GitHub. You can define get_model by running the following code:

    from functools import partial
    from models.ddpm_unet_attn import SPVUnet

    get_model = partial(SPVUnet, in_channels=4, voxel_size=0.1, nfs=(32, 64, 128, 256), num_layers=1, attn_chans=8, attn_start=3)
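
For reference, loading the shared checkpoint into this model would look roughly like the sketch below (the filename assumes the chair completion checkpoint from the Drive folder, and the 'state_dict' key is how the weights are stored in that file; adjust the path to wherever you saved it):

    import torch

    # Sketch: build the small SPVD variant and load the provided weights.
    model = get_model()
    state = torch.load('CompletionSPVD_S_Chair.pt')  # path to the downloaded checkpoint
    model.load_state_dict(state['state_dict'])
    model = model.eval().cuda()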
jianchaoci commented 1 week ago

Thank you for your patient response. I tried training the network for more epochs (~100), and the problem went away.

JohnRomanelis commented 1 week ago

Could you try using the TrainCompletion notebook? Use the lines of code above to define get_model, and try loading the checkpoint I provided instead of training. This experiment should work; if not, let me know and we can find another solution.

jianchaoci commented 1 week ago

I ran a quick test. Basically, the code works and the checkpoint loaded successfully for completion, but I found the completion output is not quite right for some partial point clouds. [two screenshots of the completion results attached, 2024-11-14] Here is the test code:

    %cd /home/jianchao/git_opensource_pro/SPVD/

    from functools import partial
    import time

    import torch
    import torch.nn as nn

    from models.ddpm_unet_attn import SPVUnet
    from pclab.learner import *
    from utils.callbacks import *
    from pclab.learner import Callback
    from utils.completion_schedulers import DDPMSparseCompletionSchedulerGPU
    from utils.visualization import quick_vis_batch, vis_pc_sphere

    def pad(t, np):
        B, N, F = t.shape
        padded = torch.zeros(B, np, F).to(t)
        padded[:, :N, :] = t
        return padded

    get_model = partial(SPVUnet, in_channels=4, voxel_size=0.1, nfs=(32, 64, 128, 256), num_layers=1, attn_chans=8, attn_start=3)
    model = get_model()
    model.load_state_dict(torch.load('/home/jianchao/git_opensource_pro/SPVD/CompletionSPVD_S_Chair.pt')['state_dict'])
    model = model.eval().cuda()

    sched = DDPMSparseCompletionSchedulerGPU()

    batch = next(iter(te_dl))
    pc_batch = batch['input'].F.reshape(32, 2048, 4)[..., :3]
    mask_batch = batch['mask']

    for idx in range(7, 32):
        pc = pc_batch[idx]
        print(pc.shape)
        mask = mask_batch[idx]
        pc = pc[mask]
        start = time.time()
        preds = sched.complete(pc.unsqueeze(0), model, n_points=2048, save_process=False)
        print(time.time() - start)
        quick_vis_batch(torch.cat([pad(pc.unsqueeze(0), 2048), preds], dim=0), grid=(2, 1), x_offset=6)

JohnRomanelis commented 1 week ago

Could you also check that the data you’re using is properly normalized? You can use the sample data I provided along with the checkpoint to verify this. Also, this is likely not the exact checkpoint used for the paper results, which might explain some differences in performance.
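
For a concrete check, a common convention is to center each cloud and scale it to the unit sphere. Whether this matches the exact preprocessing expected here is an assumption, so comparing value ranges against the provided sample data is the safer test. A minimal sketch:

    import torch

    def normalize_unit_sphere(pc: torch.Tensor) -> torch.Tensor:
        # pc: (N, 3) point coordinates. Assumed normalization, not a confirmed
        # SPVD spec -- verify against the provided sample data.
        pc = pc - pc.mean(dim=0, keepdim=True)   # center at the origin
        scale = pc.norm(dim=1).max()             # radius of the furthest point
        return pc / scale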