Open yumion opened 5 years ago
In the current machina implementation, observations are stored as flattened vectors, so you should flatten the observation as in the following code.
import argparse
import json
import os
from pprint import pprint
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym import ObservationWrapper
from pybullet_envs.bullet.kukaCamGymEnv import KukaCamGymEnv
from machina.pols import ArgmaxQfPol
from machina.algos import qtopt
from machina.vfuncs import DeterministicSAVfunc, CEMDeterministicSAVfunc
from machina.envs import GymEnv
from machina.traj import Traj
from machina.traj import epi_functional as ef
from machina.samplers import EpiSampler
from machina import logger
from machina.utils import set_device, measure
class FlattenedObservationWrapper(ObservationWrapper):
    """Observation wrapper that returns every observation as a 1-D array.

    Attributes:
        observation_space: the wrapped environment's own (un-flattened)
            observation space, re-exported unchanged.
        flattend_observation_space: a sample drawn from that space and
            reshaped to one dimension. Note that this is an array, not a gym
            space object; the rest of this script passes it as the
            observation-space argument of the machina value functions and
            policy.
    """

    def __init__(self, env):
        """Wrap ``env`` and record both the original and the flat layout."""
        super().__init__(env)
        wrapped_space = env.observation_space
        # Keep advertising the original space: QTOptNet reads its shape to
        # restore the image layout from the flat vector.
        self.observation_space = wrapped_space
        # One sample is drawn only to obtain an object with the flat shape.
        flat_sample = wrapped_space.sample().reshape(-1)
        self.flattend_observation_space = flat_sample

    def observation(self, observation):
        """Return ``observation`` collapsed to a single dimension."""
        flat_observation = observation.reshape(-1)
        return flat_observation
class QTOptNet(nn.Module):
    """Convolutional Q-network mapping (flat image observation, action) to a
    value in [0, 1].

    The observation arrives as a flat vector per sample and is reshaped to
    ``observation_space.shape``, which must be ``(height, width, channels)``.
    The action is embedded by two fully connected layers, broadcast over the
    spatial grid of the image features, and added to them before the final
    convolutions.

    The number of input channels and the size of the flattened feature map
    are derived from ``observation_space.shape``; for ``(341, 256, 4)`` they
    evaluate to 4 and 4032.

    NOTE: ``conv2`` and ``conv3`` are each a single module applied several
    times in ``forward``, so those repeated passes share one set of weights.

    Args:
        observation_space: object whose ``shape`` is (height, width, channels).
        action_space: object whose ``shape[0]`` is the action dimension.

    Raises:
        ValueError: if the observation shape is not 3-D, or is too small to
            survive the stride/pooling stack.
    """

    def __init__(self, observation_space, action_space):
        super().__init__()
        obs_shape = tuple(int(dim) for dim in observation_space.shape)
        if len(obs_shape) != 3:
            raise ValueError(
                'observation_space.shape must be (height, width, channels), '
                'got {}'.format(obs_shape))
        height, width, channels = obs_shape
        feat_h = self._spatial_out(height)
        feat_w = self._spatial_out(width)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                'observation of shape {} is too small for this '
                'network'.format(obs_shape))

        # conv
        self.conv1 = nn.Conv2d(channels, 64, kernel_size=6, stride=2,
                               padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        # pool
        self.pool1 = nn.MaxPool2d(3, stride=3)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # fc
        self.fc1 = nn.Linear(action_space.shape[0], 256)
        self.fc2 = nn.Linear(256, 64)
        # 64 channels times the spatial grid left after conv1 and the pools.
        self.fc3 = nn.Linear(64 * feat_h * feat_w, 64)
        self.fc4 = nn.Linear(64, 64)
        self.output_layer = nn.Linear(64, 1)
        self.observation_space = observation_space

    @staticmethod
    def _spatial_out(size):
        """Return the length of one spatial axis after all size-changing layers.

        conv2 and conv3 preserve the size (stride 1 with matching padding), so
        only conv1 and the three pooling steps of ``forward`` are counted.
        """
        size = (size + 2 * 2 - 6) // 2 + 1  # conv1: kernel 6, stride 2, pad 2
        size = (size - 3) // 3 + 1          # first pool1: kernel 3, stride 3
        size = (size - 3) // 3 + 1          # second pool1
        size = (size - 2) // 2 + 1          # pool2: kernel 2, stride 2
        return size

    def forward(self, ob, ac):
        """Compute the Q-value for a batch of observations and actions.

        Args:
            ob: tensor whose first axis is the batch and whose remaining
                elements reshape to ``observation_space.shape`` per sample.
            ac: tensor of shape (batch, action_dim).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1] (sigmoid output).
        """
        batch_size = ob.shape[0]
        ob = ob.reshape([batch_size] + list(self.observation_space.shape))
        ob = ob.permute(0, 3, 1, 2)  # channel last -> channel first

        # observation net
        ob = F.relu(self.conv1(ob))
        ob = self.pool1(ob)
        for _ in range(6):
            ob = F.relu(self.conv2(ob))
        ob = self.pool1(ob)

        # action net
        ac = F.relu(self.fc1(ac))
        ac = F.relu(self.fc2(ac))
        ac = ac.view(-1, 64, 1, 1)

        # Add the action feature to every spatial location. Broadcasting
        # gives the same result as tiling ``ac`` to the feature-map size
        # without allocating the tiled tensor.
        h = ob + ac
        for _ in range(6):
            h = F.relu(self.conv3(h))
        h = self.pool2(h)
        for _ in range(3):
            h = F.relu(self.conv3(h))

        h = h.view(h.size()[0], -1)  # flatten
        h = F.relu(self.fc3(h))
        h = F.relu(self.fc4(h))
        return torch.sigmoid(self.output_layer(h))
# Command-line interface of the QT-Opt training script.
# Only help texts were corrected; flag names, types and defaults are unchanged.
parser = argparse.ArgumentParser()
parser.add_argument('--log', type=str, default='garbage',
                    help='Directory name of log.')
parser.add_argument('--record', action='store_true',
                    default=False, help='If True, movie is saved.')
parser.add_argument('--seed', type=int, default=256)
parser.add_argument('--max_epis', type=int,
                    default=100000000, help='Number of episodes to run.')
# This value is handed to Traj(...) as its capacity, i.e. it counts steps.
parser.add_argument('--max_steps_off', type=int,
                    default=1000000000000,
                    help='Number of steps stored in off traj.')
parser.add_argument('--num_parallel', type=int, default=4,
                    help='Number of processes to sample.')
parser.add_argument('--cuda', type=int, default=-1,
                    help='cuda device number.')
parser.add_argument('--data_parallel', action='store_true', default=False,
                    help='If True, inference is done in parallel on gpus.')
parser.add_argument('--max_steps_per_iter', type=int, default=4000,
                    help='Number of steps to use in an iteration.')
parser.add_argument('--batch_size', type=int, default=32)
# NOTE(review): --pol_lr is parsed but not read anywhere in the visible script.
parser.add_argument('--pol_lr', type=float, default=1e-4,
                    help='Policy learning rate.')
parser.add_argument('--qf_lr', type=float, default=1e-3,
                    help='Q function learning rate.')
parser.add_argument('--tau', type=float, default=0.0001,
                    help='Coefficient of target function.')
parser.add_argument('--gamma', type=float, default=0.9,
                    help='Discount factor.')
parser.add_argument('--lag', type=int, default=6000,
                    help='Lag of gradient steps of target function2.')
parser.add_argument('--num_iter', type=int, default=2,
                    help='Number of iteration of CEM.')
parser.add_argument('--num_sampling', type=int, default=60,
                    help='Number of samples sampled from Gaussian in CEM.')
parser.add_argument('--num_best_sampling', type=int, default=6,
                    help='Number of best samples used for fitting Gaussian in CEM.')
# NOTE(review): the wording below reads inverted relative to the flag name;
# confirm the semantics against CEMDeterministicSAVfunc before rewording.
parser.add_argument('--multivari', action='store_true',
                    help='If true, Gaussian with diagonal covariance instead of Multivariate Gaussian matrix is used in CEM.')
parser.add_argument('--eps', type=float, default=0.2,
                    help='Probability of random action in epsilon-greedy policy.')
parser.add_argument('--loss_type', type=str,
                    choices=['mse', 'bce'], default='mse',
                    help='Choice for type of Bellman loss.')
parser.add_argument('--save_memory', action='store_true',
                    help='If true, save memory at the cost of more computation time by using a for loop.')
args = parser.parse_args()
# ---- Logging directories -------------------------------------------------
# Create <log>/ and <log>/models in one call. exist_ok avoids the
# check-then-create race, and makedirs also creates missing parents.
os.makedirs(os.path.join(args.log, 'models'), exist_ok=True)

with open(os.path.join(args.log, 'args.json'), 'w') as f:
    json.dump(vars(args), f)
pprint(vars(args))

# ---- Reproducibility and device ------------------------------------------
np.random.seed(args.seed)
torch.manual_seed(args.seed)

device_name = 'cpu' if args.cuda < 0 else "cuda:{}".format(args.cuda)
device = torch.device(device_name)
set_device(device)

score_file = os.path.join(args.log, 'progress.csv')
logger.add_tabular_output(score_file)

# ---- Environment ---------------------------------------------------------
env = KukaCamGymEnv()
env = FlattenedObservationWrapper(env)
# Must be read before GymEnv wraps the env: it is an attribute of the
# flattening wrapper.
flattend_observation_space = env.flattend_observation_space
env = GymEnv(env, log_dir=os.path.join(
    args.log, 'movie'), record_video=args.record)
env.env.seed(args.seed)

observation_space = env.observation_space
action_space = env.action_space

# ---- Networks ------------------------------------------------------------
# The networks receive the space exposed by the env because QTOptNet reshapes
# the flat observation back to that space's shape.
qf_net = QTOptNet(observation_space, action_space)
lagged_qf_net = QTOptNet(observation_space, action_space)
lagged_qf_net.load_state_dict(qf_net.state_dict())
targ_qf1_net = QTOptNet(observation_space, action_space)
targ_qf1_net.load_state_dict(qf_net.state_dict())
targ_qf2_net = QTOptNet(observation_space, action_space)
targ_qf2_net.load_state_dict(lagged_qf_net.state_dict())

# ---- Value functions and policy ------------------------------------------
# Every machina wrapper is given the flattened space, matching the flat
# vectors the environment wrapper emits. qf previously received the
# image-shaped space, unlike the other three wrappers around identical nets.
# NOTE(review): presumably machina compares observation rank against this
# space - confirm against the machina version in use.
qf = DeterministicSAVfunc(
    flattend_observation_space, action_space, qf_net,
    data_parallel=args.data_parallel)
lagged_qf = DeterministicSAVfunc(
    flattend_observation_space, action_space,
    lagged_qf_net,
    data_parallel=args.data_parallel)
targ_qf1 = CEMDeterministicSAVfunc(
    flattend_observation_space, action_space,
    targ_qf1_net, num_sampling=args.num_sampling,
    num_best_sampling=args.num_best_sampling, num_iter=args.num_iter,
    multivari=args.multivari, data_parallel=args.data_parallel,
    save_memory=args.save_memory)
targ_qf2 = DeterministicSAVfunc(
    flattend_observation_space, action_space, targ_qf2_net,
    data_parallel=args.data_parallel)

pol = ArgmaxQfPol(flattend_observation_space,
                  action_space, targ_qf1, eps=args.eps)

sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

optim_qf = torch.optim.Adam(qf_net.parameters(), args.qf_lr)

# Replay storage, kept on the CPU.
off_traj = Traj(args.max_steps_off, traj_device='cpu')

# ---- Training counters ---------------------------------------------------
total_epi = 0
total_step = 0
total_grad_step = 0
num_update_lagged = 0
max_rew = -1e6
# The parser of this script defines no --env_name option, so reading
# args.env_name raised AttributeError on the first logging call. Fall back to
# a fixed title, while still honouring the option should it be added.
plot_title = getattr(args, 'env_name', 'KukaCamGymEnv')


def _save_checkpoint(suffix):
    """Save policy, value functions and optimizer as ``<name>_<suffix>.pkl``."""
    targets = (
        ('pol', pol),
        ('qf', qf),
        ('targ_qf1', targ_qf1),
        ('targ_qf2', targ_qf2),
        ('optim_qf', optim_qf),
    )
    for name, obj in targets:
        torch.save(obj.state_dict(), os.path.join(
            args.log, 'models', '{}_{}.pkl'.format(name, suffix)))


def _set_dp_run(flag):
    """Switch data-parallel execution on or off for all value functions."""
    for vfunc in (qf, lagged_qf, targ_qf1, targ_qf2):
        vfunc.dp_run = flag


# Main loop: sample episodes with the current policy, add them to the replay
# trajectory, run QT-Opt updates, then log and checkpoint.
while args.max_epis > total_epi:
    with measure('sample'):
        epis = sampler.sample(pol, max_steps=args.max_steps_per_iter)

    with measure('train'):
        on_traj = Traj(traj_device='cpu')
        on_traj.add_epis(epis)

        on_traj = ef.add_next_obs(on_traj)
        on_traj.register_epis()

        off_traj.add_traj(on_traj)

        total_epi += on_traj.num_epi
        step = on_traj.num_step
        total_step += step
        # One gradient step per freshly sampled environment step.
        epoch = step

        if args.data_parallel:
            _set_dp_run(True)

        result_dict = qtopt.train(
            off_traj, qf, lagged_qf, targ_qf1, targ_qf2,
            optim_qf, epoch, args.batch_size,
            args.tau, args.gamma, loss_type=args.loss_type
        )

        if args.data_parallel:
            _set_dp_run(False)

        total_grad_step += epoch
        # Refresh the lagged network each time another `lag` gradient steps
        # have accumulated (also fires on the first iteration, when
        # num_update_lagged is 0).
        if total_grad_step >= args.lag * num_update_lagged:
            logger.log('Updated lagged qf!!')
            lagged_qf_net.load_state_dict(qf_net.state_dict())
            num_update_lagged += 1

    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(args.log, result_dict, score_file,
                          total_epi, step, total_step,
                          rewards,
                          plot_title=plot_title)

    # Keep a separate copy of the best-scoring models so far.
    if mean_rew > max_rew:
        _save_checkpoint('max')
        max_rew = mean_rew

    _save_checkpoint('last')
    del on_traj
del sampler
Instead of defining your own FlattenedObservationWrapper class
as @iory suggests,
you can also use gym.wrappers.FlattenDictWrapper
(https://github.com/openai/gym/blob/master/gym/wrappers/dict.py#L8)
to flatten the observation.
Thanks @iory and @takerfume!
It now runs without errors through sampling,
but the process was killed during training.
The error output is below.
fish: “python run_qtopt_bulletKukaCam.…” terminated by signal SIGSEGV (Address boundary error)
I guessed that the GPU had run out of available memory, so I decreased batch_size
to 1 and num_sampling
to 6.
However, it then ran until the 3rd sampling phase,
but was killed again during the 3rd training phase.
My environment is here.
CPU: Intel i9-7920X CPU @ 2.90GHz × 24
GPU: TITAN RTX (25GB), TITAN V(12GB) x2
RAM: 125.6GB
OS: Ubuntu 16.04
Is this something that cannot be helped? Please give me some advice.
Hello, I'm always indebted to machina.
I implemented QT-Opt for a grasping task in PyBullet (https://github.com/bulletphysics/bullet3/blob/master/examples/pybullet/gym/pybullet_envs/baselines/train_kuka_cam_grasping.py). The observation of this environment is an RGBD image (
observation_space = Box(341, 256, 4)
) and the action is the displacement in x and y plus the gripper angle (action_space = Box(3,)
). I built a CNN that takes the image as input, but I got this error when sampling trajectories.
I also tried another environment that is similar to the former one, except that its observation is not an RGBD image (
observation_space = Box(9, )
), so the network is an MLP rather than a CNN, and it worked without errors. Is there a problem with using a CNN (or is my network wrong)? Please give me advice on how to solve this error. For reference, here is my NN architecture.