NVIDIA / MinkowskiEngine

Minkowski Engine is an auto-diff neural network library for high-dimensional sparse tensors
https://nvidia.github.io/MinkowskiEngine

Sampling from VAE's latent space #566

Open seakforzq opened 11 months ago

seakforzq commented 11 months ago

Is your feature request related to a problem? Please describe.
How do I sample from the VAE's latent space?


        # Encode the input and decode it back against the input's own coordinates
        target_key = sin.coords_key
        out_cls, targets, sout, means, log_vars, zs = net(sin, target_key)

        # Per-layer binary cross-entropy on the occupancy classifications
        num_layers, BCE = len(out_cls), 0
        losses = []
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(), target.type(out_cl.F.dtype).to(device))
            losses.append(curr_loss.item())
            BCE += curr_loss / num_layers

        # KL divergence between q(z|x) = N(means, exp(log_vars)) and N(0, I)
        KLD = -0.5 * torch.mean(
            torch.sum(1 + log_vars.F - means.F.pow(2) - log_vars.F.exp(), 1)
        )
        loss = KLD + BCE

        print(loss)

        # Visualize each reconstruction side by side with its input point cloud
        batch_coords, batch_feats = sout.decomposed_coordinates_and_features
        for b, (coords, feats) in enumerate(zip(batch_coords, batch_feats)):
            pcd = PointCloud(coords)
            pcd.estimate_normals()
            pcd.translate([0.6 * config.resolution, 0, 0])
            pcd.rotate(M)
            opcd = PointCloud(data_dict["xyzs"][b])
            opcd.translate([-0.6 * config.resolution, 0, 0])
            opcd.estimate_normals()
            opcd.rotate(M)
            o3d.visualization.draw_geometries([pcd, opcd])

            n_vis += 1
            if n_vis > config.max_visualization:
                return
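
For reference, the KLD term above is the standard closed-form KL divergence between the diagonal Gaussian posterior q(z|x) = N(mu, sigma^2 I) (where means.F holds mu and log_vars.F holds log sigma^2) and the unit Gaussian prior, averaged over the batch:

        D_{\mathrm{KL}}\left(\mathcal{N}(\mu,\sigma^2 I)\,\middle\|\,\mathcal{N}(0,I)\right)
            = -\frac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)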

Describe the solution you'd like
The code above only performs reconstruction, and it needs target_key from the input sin, so every output is tied to an existing input. Can you help me with sampling from the latent space directly?

Describe alternatives you've considered
None.

Additional context
None.
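
For concreteness, this is roughly the sampling path I am after: a minimal sketch assuming the decoder can be run on its own. Here net.decoder, the target_key=None call, batch_size, and latent_dim are my guesses rather than the actual API, and the decoder's forward would presumably need to skip the ground-truth target lookup when no key is given:

        import torch
        import MinkowskiEngine as ME

        net.eval()  # disable the forced `keep1 += target` pruning branch
        with torch.no_grad():
            # Guess: one latent voxel at the origin per batch element, mirroring
            # a single-coordinate bottleneck; first column is the batch index.
            coords = torch.IntTensor([[b, 0, 0, 0] for b in range(batch_size)])
            feats = torch.randn(batch_size, latent_dim, device=device)  # z ~ N(0, I)
            z = ME.SparseTensor(feats, coords.to(device))
            # Hypothetical decoder-only call; in the current code the decoder
            # still expects a target_key derived from a real input.
            out_cls, targets, sout = net.decoder(z, target_key=None)
            # `sout` would then hold the generated geometry and could be
            # visualized via decomposed_coordinates_and_features as above.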

seakforzq commented 11 months ago

And why do we need gt_target to prune the upsampled sparse tensor?


        # If training, force target shape generation, use net.eval() to disable
        if self.training:
            keep1 += target

I think the network will not learn how to prune its output this way, right?
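
My reading of the surrounding decoder code, with guessed local names (upsampled, classifier, pruning as an ME.MinkowskiPruning, and target from the ground truth), is that the pruning mask itself is learned through the per-layer BCE on the classifier output, and keep1 += target only force-keeps the ground-truth voxels during training so the deeper layers always see the full target geometry, a form of teacher forcing that net.eval() turns off:

        # Guessed reconstruction of the pruning step; names are assumptions.
        out1_cls = classifier(upsampled)    # per-voxel occupancy logits,
                                            # supervised by the BCE loss above
        keep1 = (out1_cls.F > 0).squeeze()  # keep voxels predicted occupied

        # If training, force target shape generation, use net.eval() to disable
        if self.training:
            keep1 += target                 # boolean OR: also keep every
                                            # ground-truth voxel

        upsampled = pruning(upsampled, keep1)  # drop all other voxels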