HiroIshida / hifuku

Code for paper https://arxiv.org/abs/2405.02968
3 stars 0 forks source link

Debug tbdr task #20

Open HiroIshida opened 1 year ago

HiroIshida commented 1 year ago

実行可能性が確率的なのはなぜだろう Screenshot from 2023-07-05 22-22-24

HiroIshida commented 1 year ago

上のやつは, res.initが正しくセットされていなかった. それをなおしてたら次のようにちゃんとなった! Screenshot from 2023-07-05 23-01-20

HiroIshida commented 1 year ago

RRT result. 同じ問題でもとけたりとけなかったり. (ert 0.5) Screenshot from 2023-07-05 23-51-11

RRT result with ert 0.1. Screenshot from 2023-07-05 23-54-15

import tqdm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from hifuku.llazy.dataset import LazyDecomplessDataset
from hifuku.types import RawData
from hifuku.domain import TBDR_SQP_Domain, TBDR_RRT_Domain
from hifuku.datagen.batch_solver import DistributedBatchProblemSolver, MultiProcessBatchProblemSolver
from rpbench.pr2.tabletop import TabletopBoxDualArmReachingTask
from pathlib import Path
import pickle
import tqdm
from skmp.solver.ompl_solver import OMPLSolver, OMPLSolverConfig
from skmp.solver.nlp_solver.sqp_based_solver import SQPBasedSolverConfig, SQPBasedSolver
from rpbench.pr2.common import PR2InteractiveTaskVisualizer
from rpbench.pr2.tabletop import TabletopBoxDualArmReachingTask

import argparse
# Command-line options:
#   -n         number of random (non-standard) tasks to sample and solve
#   --nocache  ignore an existing cache file and re-solve from scratch
parser = argparse.ArgumentParser()
parser.add_argument("-n", default=100, type=int, help="n_sample")
parser.add_argument("--nocache", action="store_true")
args = parser.parse_args()
options = vars(args)
n_sample: int = options["n"]
nocache: bool = options["nocache"]

# Domain under investigation. Its task sampler, solver type and solver config
# are all used below.
domain = TBDR_RRT_Domain

# Solve the "standard" task once. The resulting trajectory is the initial
# solution (warm start) handed to the solver for every sampled task.
standard_task = domain.task_type.sample(1, standard=True)
result = standard_task.solve_default()[0]

# Alternative kept for debugging: solve the standard task with OMPL directly
# instead of the task's default solver.
# solver = OMPLSolver.init(OMPLSolverConfig(simplify=True, n_max_call=100000))
# solver.setup(standard_task.export_problems()[0])
# result = solver.solve()

# Explicit check rather than `assert`: asserts are stripped under `python -O`,
# which would silently propagate `None` as the initial solution.
if result.traj is None:
    raise RuntimeError("failed to solve the standard task; no initial solution available")
init_solution = result.traj

# Cache of (tasks, resultss) keyed by domain and sample count. Including the
# domain prevents switching e.g. between TBDR_SQP_Domain and TBDR_RRT_Domain
# from silently reusing results solved under the other domain.
# NOTE(review): the key does not capture domain.solver_config or use_batch;
# pass --nocache after changing either.
domain_name = getattr(domain, "__name__", type(domain).__name__)
pcache = Path("/tmp/hifuku-plot-cache-{}-{}.pkl".format(domain_name, n_sample))
print("cache path: {}".format(pcache))

use_batch = True

if pcache.exists() and not nocache:
    print("load cache")
    # NOTE(review): /tmp is world-writable and unpickling runs arbitrary code;
    # acceptable for a local debug script, not on a shared machine.
    with pcache.open(mode="rb") as f:
        (tasks, resultss) = pickle.load(f)
else:
    print("create cache")
    tasks = [domain.task_type.sample(1, standard=False) for _ in range(n_sample)]
    if use_batch:
        batch_solver = DistributedBatchProblemSolver(domain.solver_type, domain.solver_config)
        # batch_solver = MultiProcessBatchProblemSolver(domain.solver_type, domain.solver_config)
        # Every task is warm-started from the standard task's solution.
        resultss = batch_solver.solve_batch(tasks, [init_solution] * len(tasks))
    else:
        resultss = []
        solver = domain.solver_type.init(domain.solver_config)
        for task in tqdm.tqdm(tasks):
            problem = task.export_problems()[0]
            solver.setup(problem)
            results = solver.solve(init_solution)
            # Wrapped in a list so each entry is indexed like the batch path
            # (the plotting code reads results[0]).
            resultss.append([results])

    with pcache.open(mode="wb") as f:
        pickle.dump((tasks, resultss), f)

# One (box x, box y, solver effort) point per sampled task. A failed solve
# (no trajectory) is charged the solver's full call budget, n_max_call.
dataset = []
for task, results in tqdm.tqdm(zip(tasks, resultss), total=len(tasks)):
    x, y, _ = task.world.box.worldpos()
    result = results[0]
    if result.traj is None:
        n_iter = domain.solver_config.n_max_call
    else:
        n_iter = result.n_call
    dataset.append((x, y, n_iter))

# zip(*[]) cannot be unpacked into three names; stop with a clear message.
if not dataset:
    raise SystemExit("no samples to plot (n_sample={})".format(n_sample))

xs, ys, iters = zip(*dataset)

# Color each box position by the solver effort recorded for it.
plt.scatter(xs, ys, c=iters, cmap='viridis')
plt.colorbar(label="n_call (n_max_call if unsolved)")

plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter Plot')
plt.show()