DeepX-inc / machina

Deep Reinforcement Learning framework
MIT License

Can't use CNN? #248

Open yumion opened 5 years ago

yumion commented 5 years ago

Hello, and thank you for your work on machina.

I'm implementing QT-Opt for a grasping task in PyBullet (https://github.com/bulletphysics/bullet3/blob/master/examples/pybullet/gym/pybullet_envs/baselines/train_kuka_cam_grasping.py). The observation in this environment is an RGBD image (observation_space = Box(341, 256, 4)) and the action is the x and y displacement plus the gripper angle (action_space = Box(3,)).
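For reference, a quick way to check these spaces (assuming the pybullet_envs import path used later in this thread; the renders flag just keeps it headless):

from pybullet_envs.bullet.kukaCamGymEnv import KukaCamGymEnv

env = KukaCamGymEnv(renders=False)  # headless; renders=True opens the GUI
print(env.observation_space)        # Box(341, 256, 4): RGBD image
print(env.action_space)             # Box(3,): dx, dy, gripper angle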

I built a CNN to take the image as input, but I get the following error when sampling a trajectory.

Traceback (most recent call last):
  File "~/.pyenv/versions/anaconda3-5.3.0/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "~/.pyenv/versions/anaconda3-5.3.0/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "~/.pyenv/versions/anaconda3-5.3.0/lib/python3.7/site-packages/machina/samplers/epi_sampler.py", line 126, in mp_sample
    l, epi = one_epi(env, pol, deterministic_flag, prepro)
  File "~/.pyenv/versions/anaconda3-5.3.0/lib/python3.7/site-packages/machina/samplers/epi_sampler.py", line 51, in one_epi
    ac_real, ac, a_i = pol(torch.tensor(o, dtype=torch.float))
  File "~/.pyenv/versions/anaconda3-5.3.0/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "~/.pyenv/versions/anaconda3-5.3.0/lib/python3.7/site-packages/machina/pols/argmax_qf_pol.py", line 48, in forward
    q, ac = self.qfunc.max(obs)
  File "~/.pyenv/versions/anaconda3-5.3.0/lib/python3.7/site-packages/machina/vfuncs/state_action_vfuncs/cem_state_action_vfunc.py", line 77, in max
    obs = obs.repeat((1, self.num_sampling)).reshape(
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
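The failing line can be reproduced in isolation: torch.Tensor.repeat needs at least as many repeat factors as the tensor has dimensions, so the (1, num_sampling) repeat in cem_state_action_vfunc.py works on a flat observation vector but not on a 3-D image tensor.

import torch

num_sampling = 60                   # CEM candidates per observation

flat_obs = torch.zeros(9)           # MLP env: a flat observation vector
flat_obs.repeat((1, num_sampling))  # ok: 2 repeat factors >= 1 tensor dim

img_obs = torch.zeros(341, 256, 4)  # CNN env: a 3-D RGBD observation
img_obs.repeat((1, num_sampling))   # RuntimeError: fewer repeat factors
                                    # than tensor dimensions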

I also tried a similar environment whose observation is not an RGBD image (observation_space = Box(9,)), so the network is an MLP rather than a CNN, and that worked without error. Is there a problem with using a CNN, or is my network wrong? Any advice on solving this error would be appreciated.

For reference, here is my NN architecture.

import torch
import torch.nn as nn
import torch.nn.functional as F


class QTOptNet(nn.Module):
    def __init__(self, observation_space, action_space):
        super(QTOptNet, self).__init__()
        # conv
        self.conv1 = nn.Conv2d(4, 64, kernel_size=6, stride=2, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        # pool
        self.pool1 = nn.MaxPool2d(3, stride=3)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # fc
        self.fc1 = nn.Linear(action_space.shape[0], 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(4032, 64)  # 64 * 9 * 7 feature size for a 341x256 input
        self.fc4 = nn.Linear(64, 64)
        self.output_layer = nn.Linear(64, 1)

    def forward(self, ob, ac):
        ob = ob.transpose(1, 2).transpose(1, 3)  # NHWC -> NCHW (channels first)
        # observation net
        ob = F.relu(self.conv1(ob))
        ob = self.pool1(ob)
        for _ in range(6):  # the same conv2 (shared weights) applied six times
            ob = F.relu(self.conv2(ob))
        ob = self.pool1(ob)
        # action net
        ac = F.relu(self.fc1(ac))
        ac = F.relu(self.fc2(ac))
        ac = ac.view(-1, 64, 1, 1)
        # tiled layer
        ac_tiled = ac.repeat(1, 1, ob.size()[2], ob.size()[3])
        # add action feature
        h = ob + ac_tiled
        for _ in range(6):  # shared-weight conv3, applied six times
            h = F.relu(self.conv3(h))
        h = self.pool2(h)
        for _ in range(3):
            h = F.relu(self.conv3(h))
        h = h.view(h.size()[0], -1)  # flatten
        h = F.relu(self.fc3(h))
        h = F.relu(self.fc4(h))
        out = torch.sigmoid(self.output_layer(h))
        return out
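
Tracing the expected shapes through this stack for an (N, 341, 256, 4) input (my own arithmetic, worth double-checking):

# transpose              -> (N,  4, 341, 256)
# conv1 (k6, s2, p2)     -> (N, 64, 170, 128)
# pool1 (k3, s3)         -> (N, 64,  56,  42)
# conv2 x6 (k5, s1, p2)  -> (N, 64,  56,  42)
# pool1 (k3, s3)         -> (N, 64,  18,  14)
# conv3 x6 (k3, s1, p1)  -> (N, 64,  18,  14)
# pool2 (k2, s2)         -> (N, 64,   9,   7)
# conv3 x3               -> (N, 64,   9,   7)
# flatten                -> (N, 4032)  # = 64 * 9 * 7, hence fc3's input size
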
iory commented 5 years ago

In the current machina implementation, observations are stored as flattened vectors, so you should flatten the observation as in the following code.

import argparse
import json
import os
from pprint import pprint

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym import ObservationWrapper
from pybullet_envs.bullet.kukaCamGymEnv import KukaCamGymEnv

from machina.pols import ArgmaxQfPol
from machina.algos import qtopt
from machina.vfuncs import DeterministicSAVfunc, CEMDeterministicSAVfunc
from machina.envs import GymEnv
from machina.traj import Traj
from machina.traj import epi_functional as ef
from machina.samplers import EpiSampler
from machina import logger
from machina.utils import set_device, measure

class FlattenedObservationWrapper(ObservationWrapper):

    def __init__(self, env):
        super(FlattenedObservationWrapper, self).__init__(env)
        self.observation_space = env.observation_space
        # a flattened sample array, used below in place of a flattened Box space
        self.flattend_observation_space = env.observation_space.sample().\
            reshape(-1)

    def observation(self, observation):
        return observation.reshape(-1)

class QTOptNet(nn.Module):

    def __init__(self, observation_space, action_space):
        super(QTOptNet, self).__init__()
        # conv
        self.conv1 = nn.Conv2d(4, 64, kernel_size=6, stride=2, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        # pool
        self.pool1 = nn.MaxPool2d(3, stride=3)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # fc
        self.fc1 = nn.Linear(action_space.shape[0], 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(4032, 64)
        self.fc4 = nn.Linear(64, 64)
        self.output_layer = nn.Linear(64, 1)
        self.observation_space = observation_space

    def forward(self, ob, ac):
        batch_size = ob.shape[0]
        ob = ob.reshape([batch_size] + list(self.observation_space.shape))
        ob = ob.transpose(1, 2).transpose(1, 3)  # convert to channel first
        # observation net
        ob = F.relu(self.conv1(ob))
        ob = self.pool1(ob)
        for _ in range(6):
            ob = F.relu(self.conv2(ob))
        ob = self.pool1(ob)
        # action net
        ac = F.relu(self.fc1(ac))
        ac = F.relu(self.fc2(ac))
        ac = ac.view(-1, 64, 1, 1)
        # tiled layer
        ac_tiled = ac.repeat(1, 1, ob.size()[2], ob.size()[3])
        # add action feature
        h = ob + ac_tiled
        for _ in range(6):
            h = F.relu(self.conv3(h))
        h = self.pool2(h)
        for _ in range(3):
            h = F.relu(self.conv3(h))
        h = h.view(h.size()[0], -1)  # flatten
        h = F.relu(self.fc3(h))
        h = F.relu(self.fc4(h))
        out = torch.sigmoid(self.output_layer(h))
        return out

parser = argparse.ArgumentParser()
parser.add_argument('--log', type=str, default='garbage',
                    help='Directory name of log.')
parser.add_argument('--record', action='store_true',
                    default=False, help='If True, movie is saved.')
parser.add_argument('--seed', type=int, default=256)
parser.add_argument('--max_epis', type=int,
                    default=100000000, help='Number of episodes to run.')
parser.add_argument('--max_steps_off', type=int,
                    default=1000000000000, help='Number of steps stored in the off-policy trajectory.')
parser.add_argument('--num_parallel', type=int, default=4,
                    help='Number of processes to sample.')
parser.add_argument('--cuda', type=int, default=-1, help='cuda device number.')
parser.add_argument('--data_parallel', action='store_true', default=False,
                    help='If True, inference is done in parallel on gpus.')

parser.add_argument('--max_steps_per_iter', type=int, default=4000,
                    help='Number of steps to use in an iteration.')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--pol_lr', type=float, default=1e-4,
                    help='Policy learning rate.')
parser.add_argument('--qf_lr', type=float, default=1e-3,
                    help='Q function learning rate.')
parser.add_argument('--tau', type=float, default=0.0001,
                    help='Soft update coefficient for the target function.')
parser.add_argument('--gamma', type=float, default=0.9,
                    help='Discount factor.')

parser.add_argument('--lag', type=int, default=6000,
                    help='Number of gradient steps between updates of the lagged target function.')
parser.add_argument('--num_iter', type=int, default=2,
                    help='Number of iteration of CEM.')
parser.add_argument('--num_sampling', type=int, default=60,
                    help='Number of samples sampled from Gaussian in CEM.')
parser.add_argument('--num_best_sampling', type=int, default=6,
                    help='Number of best samples used for fitting Gaussian in CEM.')
parser.add_argument('--multivari', action='store_true',
                    help='If true, a Gaussian with diagonal covariance is used in CEM instead of a full multivariate Gaussian.')
parser.add_argument('--eps', type=float, default=0.2,
                    help='Probability of random action in epsilon-greedy policy.')
parser.add_argument('--loss_type', type=str,
                    choices=['mse', 'bce'], default='mse',
                    help='Type of Bellman loss.')
parser.add_argument('--save_memory', action='store_true',
                    help='If true, reduce memory usage at the cost of extra computation time by looping instead of batching.')
args = parser.parse_args()

if not os.path.exists(args.log):
    os.mkdir(args.log)

with open(os.path.join(args.log, 'args.json'), 'w') as f:
    json.dump(vars(args), f)
pprint(vars(args))

if not os.path.exists(os.path.join(args.log, 'models')):
    os.mkdir(os.path.join(args.log, 'models'))

np.random.seed(args.seed)
torch.manual_seed(args.seed)

device_name = 'cpu' if args.cuda < 0 else "cuda:{}".format(args.cuda)
device = torch.device(device_name)
set_device(device)

score_file = os.path.join(args.log, 'progress.csv')
logger.add_tabular_output(score_file)

env = KukaCamGymEnv()
env = FlattenedObservationWrapper(env)
flattend_observation_space = env.flattend_observation_space
env = GymEnv(env, log_dir=os.path.join(
    args.log, 'movie'), record_video=args.record)
env.env.seed(args.seed)

observation_space = env.observation_space
action_space = env.action_space

qf_net = QTOptNet(observation_space, action_space)
lagged_qf_net = QTOptNet(observation_space, action_space)
lagged_qf_net.load_state_dict(qf_net.state_dict())
targ_qf1_net = QTOptNet(observation_space, action_space)
targ_qf1_net.load_state_dict(qf_net.state_dict())
targ_qf2_net = QTOptNet(observation_space, action_space)
targ_qf2_net.load_state_dict(lagged_qf_net.state_dict())
qf = DeterministicSAVfunc(observation_space, action_space, qf_net,
                          data_parallel=args.data_parallel)
lagged_qf = DeterministicSAVfunc(
    flattend_observation_space, action_space,
    lagged_qf_net,
    data_parallel=args.data_parallel)
targ_qf1 = CEMDeterministicSAVfunc(
    flattend_observation_space, action_space,
    targ_qf1_net, num_sampling=args.num_sampling,
    num_best_sampling=args.num_best_sampling, num_iter=args.num_iter,
    multivari=args.multivari, data_parallel=args.data_parallel,
    save_memory=args.save_memory)
targ_qf2 = DeterministicSAVfunc(
    flattend_observation_space, action_space, targ_qf2_net, data_parallel=args.data_parallel)

pol = ArgmaxQfPol(flattend_observation_space,
                  action_space, targ_qf1, eps=args.eps)

sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

optim_qf = torch.optim.Adam(qf_net.parameters(), args.qf_lr)

off_traj = Traj(args.max_steps_off, traj_device='cpu')

total_epi = 0
total_step = 0
total_grad_step = 0
num_update_lagged = 0
max_rew = -1e6

while args.max_epis > total_epi:
    with measure('sample'):
        epis = sampler.sample(pol, max_steps=args.max_steps_per_iter)
    with measure('train'):
        on_traj = Traj(traj_device='cpu')
        on_traj.add_epis(epis)

        on_traj = ef.add_next_obs(on_traj)
        on_traj.register_epis()

        off_traj.add_traj(on_traj)

        total_epi += on_traj.num_epi
        step = on_traj.num_step
        total_step += step
        epoch = step

        if args.data_parallel:
            qf.dp_run = True
            lagged_qf.dp_run = True
            targ_qf1.dp_run = True
            targ_qf2.dp_run = True

        result_dict = qtopt.train(
            off_traj, qf, lagged_qf, targ_qf1, targ_qf2,
            optim_qf, epoch, args.batch_size,
            args.tau, args.gamma, loss_type=args.loss_type
        )

        if args.data_parallel:
            qf.dp_run = False
            lagged_qf.dp_run = False
            targ_qf1.dp_run = False
            targ_qf2.dp_run = False

    total_grad_step += epoch
    if total_grad_step >= args.lag * num_update_lagged:
        logger.log('Updated lagged qf!!')
        lagged_qf_net.load_state_dict(qf_net.state_dict())
        num_update_lagged += 1

    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(args.log, result_dict, score_file,
                          total_epi, step, total_step,
                          rewards,
                          plot_title='KukaCamGymEnv')  # args defines no env_name, so use a literal title

    if mean_rew > max_rew:
        torch.save(pol.state_dict(), os.path.join(
            args.log, 'models', 'pol_max.pkl'))
        torch.save(qf.state_dict(), os.path.join(
            args.log, 'models', 'qf_max.pkl'))
        torch.save(targ_qf1.state_dict(), os.path.join(
            args.log, 'models', 'targ_qf1_max.pkl'))
        torch.save(targ_qf2.state_dict(), os.path.join(
            args.log, 'models', 'targ_qf2_max.pkl'))
        torch.save(optim_qf.state_dict(), os.path.join(
            args.log, 'models', 'optim_qf_max.pkl'))
        max_rew = mean_rew

    torch.save(pol.state_dict(), os.path.join(
        args.log, 'models', 'pol_last.pkl'))
    torch.save(qf.state_dict(), os.path.join(
        args.log, 'models', 'qf_last.pkl'))
    torch.save(targ_qf1.state_dict(), os.path.join(
        args.log, 'models', 'targ_qf1_last.pkl'))
    torch.save(targ_qf2.state_dict(), os.path.join(
        args.log, 'models', 'targ_qf2_last.pkl'))
    torch.save(optim_qf.state_dict(), os.path.join(
        args.log, 'models', 'optim_qf_last.pkl'))
    del on_traj
del sampler
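
For intuition, the CEMDeterministicSAVfunc.max call that failed in the traceback searches for the best action with the cross-entropy method. A toy sketch of that idea (not machina's exact implementation; the function and names here are illustrative):

import torch

def cem_argmax_q(qfunc, obs, act_dim, num_iter=2, num_sampling=60, num_best=6):
    # Sample actions from a Gaussian, keep the elites, refit, repeat.
    mean, std = torch.zeros(act_dim), torch.ones(act_dim)
    for _ in range(num_iter):
        acs = mean + std * torch.randn(num_sampling, act_dim)
        # Every candidate action is paired with a copy of the observation,
        # which is why the observation has to be a flat vector.
        obs_rep = obs.unsqueeze(0).repeat(num_sampling, 1)
        qs = qfunc(obs_rep, acs).squeeze(-1)
        elite = acs[qs.topk(num_best).indices]
        mean, std = elite.mean(dim=0), elite.std(dim=0)
    return mean  # approximate argmax over actions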
takerfume commented 5 years ago

Instead of defining your own FlattenedObservationWrapper class as @iory suggests, you can also use gym.wrappers.FlattenDictWrapper (https://github.com/openai/gym/blob/master/gym/wrappers/dict.py#L8) to flatten the observation.
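
Note that FlattenDictWrapper targets Dict observation spaces; a typical use (the env name and keys below are illustrative, and assume a gym version that still ships this wrapper):

import gym
from gym.wrappers import FlattenDictWrapper

# Concatenates the chosen keys of a Dict observation into one flat Box.
env = gym.make('FetchReach-v1')
env = FlattenDictWrapper(env, ['observation', 'desired_goal', 'achieved_goal'])
print(env.observation_space)  # a single flat Box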

yumion commented 5 years ago

Thanks @iory and @takerfume! Sampling now works without errors, but the process is killed during training. The error message is below.

fish: “python run_qtopt_bulletKukaCam.…” terminated by signal SIGSEGV (Address boundary error)

I guessed the GPU was running out of memory (a rough size estimate is at the end of this comment), so I decreased batch_size to 1 and num_sampling to 6. However, it runs through the third sampling phase and is then killed during the third training phase again. My environment is as follows.

CPU: Intel i9-7920X CPU @ 2.90GHz × 24
GPU: TITAN RTX (25GB), TITAN V(12GB) x2
RAM: 125.6GB
OS: Ubuntu 16.04

Is this simply unavoidable? Any advice would be appreciated.
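
For scale, a back-of-the-envelope on just the repeated CEM inputs (float32 assumed, defaults from the script above; activations inside the CNN add considerably more):

obs_floats = 341 * 256 * 4            # 349,184 floats per flattened RGBD frame
obs_bytes = obs_floats * 4            # ~1.33 MiB each in float32
num_sampling, batch_size = 60, 32
per_eval = obs_bytes * num_sampling * batch_size
print(per_eval / 2**20)               # ~2558 MiB just for the repeated inputs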