HiroIshida / hifuku

Code for the paper https://arxiv.org/abs/2405.02968
3 stars 0 forks source link

Issue #13: visualize library

Status: Open — HiroIshida opened this issue 1 year ago

HiroIshida commented 1 year ago
import numpy as np
import time
import uuid

import numpy as np
import torch
from mohou.file import get_project_path
from rpbench.pr2.common import InteractiveTaskVisualizer, StaticTaskVisualizer

from hifuku.domain import Kivapod_Empty_RRT_Domain
from hifuku.library import LibraryBasedSolver, SolutionLibrary

def make_grid_points(lb, ub, n_grid):
    """Sample an ``n_grid``**3 lattice over the box spanned by the first 3 entries of lb/ub.

    Args:
        lb, ub: lower/upper bounds; only indices 0..2 are used.
        n_grid: number of samples per axis.

    Returns:
        ``(X, Y, Z, pts)`` where X, Y, Z are the ``np.meshgrid`` outputs (default
        "xy" indexing) and ``pts`` is an (n_grid**3, 3) array whose k-th row is
        ``(X.flat[k], Y.flat[k], Z.flat[k])``.
    """
    xlin, ylin, zlin = (np.linspace(lb[i], ub[i], n_grid) for i in range(3))
    X, Y, Z = np.meshgrid(xlin, ylin, zlin)
    # Vectorized equivalent of np.array(list(zip(X.flatten(), Y.flatten(), Z.flatten()))).
    pts = np.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=1)
    return X, Y, Z, pts


if __name__ == "__main__":
    # Render, as a plotly isosurface, the region where one predictor of the
    # library yields a non-negative "threshold - (prediction + margin)" score.
    import plotly.graph_objects as go

    domain = Kivapod_Empty_RRT_Domain
    task_type = domain.task_type
    solver_type = domain.solver_type

    pp = get_project_path("tabletop_solution_library-{}".format(domain.get_domain_name()))
    libraries = SolutionLibrary.load(pp, task_type, solver_type, torch.device("cpu"))
    lib = libraries[0]
    task = lib.task_type.sample(1, standard=True)
    lb, ub = task.get_intrinsic_description_bound()

    n_grid = 50
    X, Y, Z, pts = make_grid_points(lb, ub, n_grid)
    print(pts)

    table = task.export_table()

    # One copy of the task's vector description per grid point; columns 6:9 are
    # overwritten with the grid coordinates (presumably the 3D target position
    # in this task's description layout -- confirm).
    dummy_desc = torch.empty((len(pts), 0))
    vec = torch.from_numpy(table.get_vector_descs()[0]).float().unsqueeze(0)
    vecs = vec.repeat(len(pts), 1)
    print(vecs[0])
    vecs[:, 6:9] = torch.from_numpy(pts).float()
    print(vecs[0])

    idx = 5  # index of the predictor in the library to visualize
    if not 0 <= idx < len(lib.predictors):
        raise IndexError(
            f"predictor index {idx} is out of range; the library has {len(lib.predictors)} predictors"
        )

    threshold = lib.success_iter_threshold()
    # Inference only: skip autograd bookkeeping for n_grid**3 rows.
    with torch.no_grad():
        predicted = lib.predictors[idx].forward((dummy_desc, vecs))[0]
    values = (threshold - (predicted + lib.margins[idx])).detach().numpy().ravel()
    print(lib.solver_config.n_max_call)
    print(threshold)
    print(np.min(values))
    print(np.max(values))

    fig = go.Figure(
        go.Isosurface(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=values,
            opacity=1.0,
            surface_count=1,
            isomin=0.0,
            isomax=1000,
        )
    )
    fig.show()
HiroIshida commented 1 year ago
import numpy as np
import torch
from mohou.file import get_project_path

from hifuku.domain import Kivapod_Empty_RRT_Domain, HumanoidTableRarmReaching_SQP_Domain, HumanoidTableReachingTask
from hifuku.library import SolutionLibrary
from skrobot.model import Box
import trimesh
from skimage import measure
import time
from skrobot.model.primitives import MeshLink, Axis, LineString
from skrobot.viewers import TrimeshSceneViewer

def make_region_grid(lb, ub, n_grid):
    """Build an ``n_grid``**3 lattice over the box [lb, ub] in "ij" order.

    Args:
        lb, ub: length-3 lower/upper corners of the box.
        n_grid: number of samples per axis; must be at least 2.

    Returns:
        ``(pts, spacing)``. ``pts`` is an (n_grid**3, 3) array ordered so that
        ``pts.reshape(n_grid, n_grid, n_grid, 3)[ix, iy, iz]`` is the point
        ``(x[ix], y[iy], z[iz])``. ``spacing`` is the per-axis step as a
        tuple of floats, the form ``skimage.measure.marching_cubes`` takes.

    Raises:
        ValueError: if ``n_grid < 2`` (the step would be undefined).
    """
    if n_grid < 2:
        raise ValueError("n_grid must be at least 2, got {}".format(n_grid))
    lb = np.asarray(lb, dtype=float)
    ub = np.asarray(ub, dtype=float)
    axes = [np.linspace(lb[i], ub[i], n_grid) for i in range(3)]
    # "ij" indexing lets a plain C-order reshape of per-point values line up with
    # (x, y, z) array axes. This replaces the old "xy" meshgrid + swapaxes(0, 1).
    X, Y, Z = np.meshgrid(*axes, indexing="ij")
    pts = np.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=1)
    spacing = tuple(float(step) for step in (ub[:3] - lb[:3]) / (n_grid - 1))
    return pts, spacing


def has_zero_crossing(values):
    """Return True iff ``values`` contains both strictly negative and strictly positive entries.

    marching_cubes raises ValueError when the requested level lies outside
    [min, max] of the volume. Without a sign change there is no level-0
    surface to extract, so such volumes must be skipped.
    """
    values = np.asarray(values)
    return values.size > 0 and bool(values.min() < 0.0 < values.max())


if __name__ == "__main__":
    # For each predictor in the library, mesh the level-0 surface of
    # "threshold - (prediction + margin)" around the task's target region.
    # Then show those meshes, the end-effector paths of the predictors' initial
    # solutions, the robot, and the world in a trimesh viewer.
    domain = HumanoidTableRarmReaching_SQP_Domain
    task_type = domain.task_type
    solver_type = domain.solver_type

    pp = get_project_path("tabletop_solution_library-{}".format(domain.get_domain_name()))
    libraries = SolutionLibrary.load(pp, task_type, solver_type, torch.device("cpu"))
    lib = libraries[0]
    task: HumanoidTableReachingTask = lib.task_type.sample(1, standard=True)
    region: Box = task.world.target_region

    # Sample the target region's box enlarged by `margin` * extent on each side.
    n_grid = 50
    margin = 0.5
    region_center = region.worldpos()
    extent = np.array(region._extents)
    lb = region_center - (0.5 + margin) * extent
    ub = region_center + (0.5 + margin) * extent
    pts, spacing = make_region_grid(lb, ub, n_grid)

    table = task.export_table()

    # One copy of the task's vector description per grid point; columns 6:9 are
    # overwritten with the grid coordinates (presumably the reaching target
    # position in this task's description layout -- confirm).
    dummy_desc = torch.empty((len(pts), 0))
    vec = torch.from_numpy(table.get_vector_descs()[0]).float().unsqueeze(0)
    vecs = vec.repeat(len(pts), 1)
    print(vecs[0])
    vecs[:, 6:9] = torch.from_numpy(pts).float()
    print(vecs[0])

    threshold = lib.success_iter_threshold()
    print(lib.solver_config.n_max_call)
    print(threshold)

    mesh_links = []
    for idx, predictor in enumerate(lib.predictors):
        print(idx)
        # Inference only: skip autograd bookkeeping for n_grid**3 rows.
        with torch.no_grad():
            predicted = predictor.forward((dummy_desc, vecs))[0]
        values = (threshold - (predicted + lib.margins[idx])).detach().numpy()
        print(np.min(values))
        print(np.max(values))

        # Skip when no point is positive (nothing to show) and also when all
        # points are positive. In the latter case level 0 is outside the data
        # range and marching_cubes would raise.
        if not has_zero_crossing(values):
            continue

        volume = values.reshape(n_grid, n_grid, n_grid)  # indexed [ix, iy, iz]
        # marching_cubes_lewiner was removed from recent scikit-image releases;
        # marching_cubes(method="lewiner") is its replacement.
        verts, faces, _, _ = measure.marching_cubes(
            volume, level=0.0, spacing=spacing, method="lewiner"
        )
        verts = verts + lb  # grid-relative -> world coordinates
        faces = faces[:, ::-1]  # flip triangle winding order

        mesh = trimesh.Trimesh(vertices=verts, faces=faces)
        mesh = trimesh.smoothing.filter_laplacian(mesh)

        mesh_link = MeshLink(mesh)
        face_colors = mesh_link.visual_mesh.visual.face_colors
        face_colors[:] = (255, 0, 0, 150)  # translucent red RGBA
        mesh_links.append(mesh_link)

    config = task.config_provider.get_config()
    efkin = config.get_endeffector_kin()
    jaxon = task.config_provider.get_jaxon()

    line_links = []
    for pred in lib.predictors:
        feature_pointss, _ = efkin.map(pred.initial_solution.numpy())
        # xyz of feature point 0 at every waypoint -> one polyline per predictor.
        line_links.append(LineString(np.array(feature_pointss[:, 0, :3])))

    vis = TrimeshSceneViewer()
    task.world.visualize(vis)
    co = task.descriptions[0][0]
    axis = Axis.from_coords(co)
    vis.add(jaxon)
    vis.add(axis)
    for link in line_links + mesh_links:
        vis.add(link)
    vis.redraw()
    vis.show()
    # Keep the process alive so the viewer window stays open (presumably the
    # viewer does not block on show() -- confirm).
    time.sleep(100)