robotgradient / grasp_diffusion

PyTorch implementation of diffusion models on Lie Groups for 6D grasp pose generation https://sites.google.com/view/se3dif/home
MIT License

Minor: `n_grasps`, `n_envs` and `batch` usage in sampling scripts #4

Closed: kuldeepbrd1 closed this issue 1 year ago

kuldeepbrd1 commented 1 year ago

The scripts in scripts/sample do not use n_grasps from the CLI arguments.

n_grasps and n_envs are not used. Instead, the scripts always generate a number of grasps equal to the batch size specified in get_approximated_grasp_diffusion_field(...), as here: https://github.com/TheCamusean/grasp_diffusion/blob/3a2cb1448270798435479ee7cf8d1fbd9d5127c5/scripts/sample/generate_partial_pointcloud_6d_grasp_poses.py#L28
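For the requested count to matter, it has to be read from the parsed arguments and carried through to the sampling loop. A minimal sketch, assuming flags named after the args.n_grasps, args.obj_id and args.obj_class attributes used in the snippet below; the parse_sampling_args helper, the required=True settings and the example values are hypothetical, not taken from the repository's scripts.

```python
import argparse


def parse_sampling_args(argv=None):
    """Hypothetical parser for the sampling scripts' CLI arguments."""
    parser = argparse.ArgumentParser(description="Sample 6D grasp poses")
    # type=int makes argparse convert the value, so no int(...) cast is needed later
    parser.add_argument("--n_grasps", type=int, required=True)
    parser.add_argument("--obj_id", type=int, required=True)
    parser.add_argument("--obj_class", type=str, required=True)
    return parser.parse_args(argv)


# Arbitrary example values; n_grasps is what the script should honor
args = parse_sampling_args(["--n_grasps", "25", "--obj_id", "0", "--obj_class", "Mug"])
print(args.n_grasps)
```

The point of the issue is that a value like args.n_grasps must then determine how many poses are returned, rather than being parsed and ignored.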

Batchwise sampling of n_grasps would be nice. It would also avoid CUDA/CPU out-of-memory errors when n_grasps is large. Something like this in main(...):

if __name__ == "__main__":
    ...
    n_grasps = int(args.n_grasps)
    obj_id = int(args.obj_id)
    obj_class = args.obj_class

    batch_size = 10  # grasps produced per call to generator.sample()

    ## Set Model and Sample Generator ##
    P, mesh = sample_pointcloud(obj_id, obj_class)
    generator, model = get_approximated_grasp_diffusion_field(
        P, args, batch=batch_size, device=device
    )

    # Sample in fixed-size batches until at least n_grasps poses exist
    H_batches = []
    n_batches = int(np.ceil(n_grasps / batch_size))
    for _ in range(n_batches):
        H_batches.append(generator.sample())

    # torch.cat works on all PyTorch versions; drop the surplus from the last batch
    H = torch.cat(H_batches, dim=0)[:n_grasps]
    H[..., :3, -1] *= 1 / 8.0
    ...

and change get_approximated_grasp_diffusion_field(...) to take the batch size as an argument:

def get_approximated_grasp_diffusion_field(p, args, batch=10, device="cpu"):
    model_params = args.model

    ## Load model
    model_args = {"device": device, "pretrained_model": model_params}
    model = load_model(model_args)
    ...  # remainder unchanged, but uses the `batch` argument instead of a fixed value
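The batching logic proposed for main(...) can be checked in isolation. Below is a minimal, dependency-free sketch, assuming a sampler that returns a fixed-size batch on every call (as generator.sample() does once its batch size is set). sample_in_batches and FakeGenerator are hypothetical stand-ins, not part of the repository; with PyTorch tensors, torch.cat(batches, dim=0)[:n_samples] plays the role of the list concatenation and slice.

```python
import math
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def sample_in_batches(
    sample_batch: Callable[[], Sequence[T]], n_samples: int, batch_size: int
) -> List[T]:
    """Call a fixed-size batch sampler until n_samples items exist, then truncate."""
    if n_samples < 0 or batch_size <= 0:
        raise ValueError("n_samples must be >= 0 and batch_size > 0")
    n_batches = math.ceil(n_samples / batch_size)
    samples: List[T] = []
    for _ in range(n_batches):
        batch = sample_batch()
        # The count only works out if every call returns exactly batch_size items
        if len(batch) != batch_size:
            raise ValueError(f"expected {batch_size} samples, got {len(batch)}")
        samples.extend(batch)
    # The last batch may overshoot, so drop the surplus
    return samples[:n_samples]


class FakeGenerator:
    """Hypothetical stand-in for the diffusion sampler: returns batch_size ids per call."""

    def __init__(self, batch_size: int):
        self.batch_size = batch_size
        self.calls = 0

    def sample(self) -> List[int]:
        start = self.calls * self.batch_size
        self.calls += 1
        return list(range(start, start + self.batch_size))


gen = FakeGenerator(batch_size=10)
grasps = sample_in_batches(gen.sample, n_samples=25, batch_size=10)
print(len(grasps), gen.calls)  # 25 3
```

With n_samples = 25 and batch_size = 10, three batches (30 samples) are drawn and five are discarded. Truncating keeps the sampler's fixed batch size untouched at the cost of some wasted work in the last batch, and the work done per call scales with batch_size rather than with the total number of grasps requested.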

It's not critical to add this to the code, so I'm just flagging it here. (Happy to create a pull request too, if you'd like.)

robotgradient commented 1 year ago

Thanks @kuldeepbrd1.

You are right! This would be highly beneficial for sampling the grasps properly. If you feel up to it, create a pull request; I will test it and, if everything works smoothly, accept it.

Thanks a lot :)