Jiankai-Sun closed this issue 6 years ago.
Hi @Jiankai-Sun! Thanks for pointing that out. It should now be fixed. It was an issue when saving the model that caused this. It shouldn't be a problem anymore.
Thank you for your reply! However, "saving the model" is done by test.py, whereas the pictures above show that it is the Training Agent, which presumably comes from train.py, that occupies 2 GPUs. I've made some changes based on your work; could there be any other possible reason that triggers the problem?
My code:
test.py
from __future__ import division
from setproctitle import setproctitle as ptitle
import torch
from environment import atari_env
from utils import setup_logger, normalize_rgb_obs, frame_to_video
from model import A3Clstm
from player_util import Agent
import time
import logging
from tensorboardX import SummaryWriter
import cv2, os, shutil


def test(rank, args, shared_model):
    ptitle('Test Agent')
    gpu_id = args.gpu_ids[rank % len(args.gpu_ids)]
    writer = SummaryWriter(log_dir=args.log_dir + 'tb_test')
    log = {}
    setup_logger('{}_log'.format('Test_' + str(rank)),
                 r'{0}{1}_log'.format(args.log_dir, 'Test_' + str(rank)))
    log['{}_log'.format('Test_' + str(rank))] = logging.getLogger(
        '{}_log'.format('Test_' + str(rank)))
    d_args = vars(args)
    for k in d_args.keys():
        log['{}_log'.format('Test_' + str(rank))].info('{0}: {1}'.format(k, d_args[k]))

    torch.manual_seed(args.seed)
    if gpu_id >= 0:
        torch.cuda.manual_seed(args.seed)
    env = atari_env(env_id=rank, args=args, type='train')

    reward_sum = 0
    start_time = time.time()
    num_tests = 0
    num_inside_target_room = 0
    reward_total_sum = 0

    player = Agent(None, env, args, None)
    player.gpu_id = gpu_id
    player.model = A3Clstm(
        player.env.observation_space.shape[2], player.env.action_space.n)
    player.state = player.env.reset()
    player.state = normalize_rgb_obs(player.state)
    player.state = torch.from_numpy(player.state).float()
    if gpu_id >= 0:
        with torch.cuda.device(gpu_id):
            player.model = player.model.cuda()
            player.state = player.state.cuda()
    player.model.eval()

    action_times = 0
    while True:
        action_times += 1
        if player.done:
            if gpu_id >= 0:
                with torch.cuda.device(gpu_id):
                    player.model.load_state_dict(shared_model.state_dict())
            else:
                player.model.load_state_dict(shared_model.state_dict())

        player.action_test()
        reward_sum += player.reward

        if not os.path.exists(args.log_dir + "video/" + str(rank) + "_" + str(num_tests)):
            os.makedirs(args.log_dir + "video/" + str(rank) + "_" + str(num_tests))
        cv2.imwrite(args.log_dir + "video/" + str(rank) + "_" + str(num_tests) + "/" + str(action_times) + ".png",
                    player.env.get_rgb())  # (90, 120, 3)

        if player.done:
            frame_to_video(fileloc=args.log_dir + "video/" + str(rank) + "_" + str(num_tests) + "/%d.png", t_w=120, t_h=90,
                           destination=args.log_dir + "video/" + str(rank) + "_" + str(num_tests) + ".mp4")
            shutil.rmtree(args.log_dir + "video/" + str(rank) + "_" + str(num_tests))
            action_times = 0
            num_tests += 1
            num_inside_target_room += player.env.inside_target_room
            reward_total_sum += reward_sum
            reward_mean = reward_total_sum / num_tests
            success_rate = num_inside_target_room / num_tests
            log['{}_log'.format('Test_' + str(rank))].info(
                "Time {0}, Tester {1}, Test {2}, episode reward {3}, episode length {4}, reward mean {5:.4f}, success rate {6}".
                format(
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - start_time)), rank,
                    num_tests, reward_sum, player.eps_len, reward_mean, success_rate))

            # Tensorboard
            writer.add_scalar("data/episode_reward", reward_sum, num_tests)
            writer.add_scalar("data/episode_length", player.eps_len, num_tests)
            writer.add_scalar("data/reward_mean", reward_mean, num_tests)
            writer.add_scalar("data/success_rate", success_rate, num_tests)

            if reward_sum > args.save_score_level:
                # player.model.load_state_dict(shared_model.state_dict())
                if gpu_id >= 0:
                    with torch.cuda.device(gpu_id):
                        state_to_save = player.model.state_dict()
                        torch.save(state_to_save, '{0}{1}_{2}.dat'.format(
                            args.save_model_dir, 'Test_' + str(rank), reward_sum))
                else:
                    state_to_save = player.model.state_dict()
                    torch.save(state_to_save, '{0}{1}_{2}.dat'.format(
                        args.save_model_dir, 'Test_' + str(rank), reward_sum))

            reward_sum = 0
            player.eps_len = 0
            state = player.env.reset()
            time.sleep(10)
            state = normalize_rgb_obs(state)
            player.state = torch.from_numpy(state).float()
            if gpu_id >= 0:
                with torch.cuda.device(gpu_id):
                    player.state = player.state.cuda()
train.py
from __future__ import division
from setproctitle import setproctitle as ptitle
import torch
import torch.optim as optim
from environment import atari_env
from utils import ensure_shared_grads, normalize_rgb_obs
from model import A3Clstm
from player_util import Agent
from torch.autograd import Variable
from utils import setup_logger
import logging, os
from tensorboardX import SummaryWriter
from torchvision.utils import save_image


def train(rank, args, shared_model, optimizer):
    ptitle('Training Agent: {}'.format(rank))
    gpu_id = args.gpu_ids[rank % len(args.gpu_ids)]
    writer = SummaryWriter(log_dir=args.log_dir + 'tb_train')
    log = {}
    setup_logger('{}_train_log'.format(rank),
                 r'{0}{1}_train_log'.format(args.log_dir, rank))
    log['{}_train_log'.format(rank)] = logging.getLogger(
        '{}_train_log'.format(rank))

    torch.manual_seed(args.seed + rank)
    if gpu_id >= 0:
        torch.cuda.manual_seed(args.seed + rank)
    env = atari_env(env_id=rank, args=args, type='train')

    if optimizer is None:
        if args.optimizer == 'RMSprop':
            optimizer = optim.RMSprop(shared_model.parameters(), lr=args.lr)
        if args.optimizer == 'Adam':
            optimizer = optim.Adam(
                shared_model.parameters(), lr=args.lr, amsgrad=args.amsgrad)

    env.seed(args.seed + rank)
    player = Agent(None, env, args, None)
    player.gpu_id = gpu_id
    player.model = A3Clstm(
        player.env.observation_space.shape[2], player.env.action_space.n)
    player.state = player.env.reset()
    player.state = normalize_rgb_obs(player.state)
    player.state = torch.from_numpy(player.state).float()
    if gpu_id >= 0:
        with torch.cuda.device(gpu_id):
            player.state = player.state.cuda()
            player.model = player.model.cuda()
    player.model.train()

    num_trains = 0
    if not os.path.exists(args.log_dir + "images/"):
        os.makedirs(args.log_dir + "images/")

    while True:
        if gpu_id >= 0:
            with torch.cuda.device(gpu_id):
                player.model.load_state_dict(shared_model.state_dict())
        else:
            player.model.load_state_dict(shared_model.state_dict())

        for step in range(args.num_steps):
            player.action_train()
            if player.done:
                break

        if player.done:
            num_trains += 1
            log['{}_train_log'.format(rank)].info('entropy:{0}'.format(player.entropy.data[0]))
            writer.add_scalar("data/entropy_" + str(rank), player.entropy.data[0], num_trains)
            writer.add_image('FCN', player.fcn, num_trains)
            writer.add_image('Depth_GroundTruth', player.depth, num_trains)
            player.eps_len = 0
            player.current_life = 0
            state = player.env.reset()
            state = normalize_rgb_obs(state)
            player.state = torch.from_numpy(state).float()
            if gpu_id >= 0:
                with torch.cuda.device(gpu_id):
                    player.state = player.state.cuda()

        R = torch.zeros(1, 1)
        if not player.done:
            value, _, _, _ = player.model(
                (Variable(player.state.unsqueeze(0)), (player.hx, player.cx),
                 Variable(torch.from_numpy(player.env.target).type(torch.FloatTensor).cuda())))
            R = value.data

        if gpu_id >= 0:
            with torch.cuda.device(gpu_id):
                R = R.cuda()

        player.values.append(Variable(R))
        policy_loss = 0
        value_loss = 0
        gae = torch.zeros(1, 1)
        if gpu_id >= 0:
            with torch.cuda.device(gpu_id):
                gae = gae.cuda()
        R = Variable(R)
        for i in reversed(range(len(player.rewards))):
            R = args.gamma * R + player.rewards[i]
            advantage = R - player.values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)
            # Generalized Advantage Estimation
            delta_t = args.gamma * player.values[i + 1].data + player.rewards[i] - player.values[i].data
            gae = gae * args.gamma * args.tau + delta_t
            policy_loss = policy_loss - \
                player.log_probs[i] * \
                Variable(gae) - 0.01 * player.entropies[i] \
                + player.fcn_losses[i]  # FCN

        writer.add_scalar("data/value_loss_" + str(rank), value_loss, num_trains)
        player.model.zero_grad()
        (policy_loss + 0.5 * value_loss).backward()
        torch.nn.utils.clip_grad_norm(player.model.parameters(), 40.0)
        ensure_shared_grads(player.model, shared_model, gpu=gpu_id >= 0)
        optimizer.step()
        player.clear_actions()
From the picture below, we can see that process 7926 occupies 2 GPUs. Processes 7928, 9665, and 9669 behave the same way.
Yup, you have a lot of additions and differences from mine. I don't have this issue on my end, but I'm pretty certain your problem is here:
if not player.done:
    value, _, _, _ = player.model(
        (Variable(player.state.unsqueeze(0)), (player.hx, player.cx),
         Variable(torch.from_numpy(player.env.target).type(torch.FloatTensor).cuda())))
    R = value.data
When you call cuda() there, it is going to go to your default GPU device, which is presumably GPU 0. Just amend it to match my code there and you should be good. Or, if you still want to call cuda() there, you can, but then you need to assign it to the GPU device you want; you shouldn't have to, though, if it is already assigned to the right GPU in the player_util.py file as in the repo.
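As a rough sketch of what I mean (not my exact code, just the same device-pinning pattern you already use elsewhere in your train.py; gpu_id is the one set at the top of train()):

if not player.done:
    # Build the target tensor on CPU first, then move it inside the device
    # context so it lands on gpu_id instead of the default GPU 0.
    target = torch.from_numpy(player.env.target).type(torch.FloatTensor)
    if gpu_id >= 0:
        with torch.cuda.device(gpu_id):
            target = target.cuda()
    value, _, _, _ = player.model(
        (Variable(player.state.unsqueeze(0)), (player.hx, player.cx),
         Variable(target)))
    R = value.data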
Hope that helps.
Thank you for your help. The problem has been solved!
First, thank you for your great work on this A3C implementation.
I ran the code with
python main.py --workers 1 --gpu-ids 5
and found that one process runs on 2 GPUs. The same thing happened when I ran with --workers 50. All the processes should run on GPU 5. However, I find that all of these processes (same PIDs) also appear on GPU 0, with Type C and a smaller GPU Memory Usage than the copies running on GPU 5. How can I assign all the processes to GPU 5? Thank you very much!
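For context, a simplified sketch (assumed names, based on the train.py/test.py above) of how --gpu-ids is meant to map workers to devices: each worker gets gpu_ids[rank % len(gpu_ids)], and CUDA allocations are wrapped in torch.cuda.device so they land on that card; any bare .cuda() call outside such a block falls back to the default device, GPU 0.

import torch

def worker_device_demo(rank, gpu_ids):
    # e.g. --gpu-ids 5 means gpu_ids == [5], so every worker maps to GPU 5
    gpu_id = gpu_ids[rank % len(gpu_ids)]
    tensor = torch.zeros(1, 1)
    if gpu_id >= 0:
        with torch.cuda.device(gpu_id):
            tensor = tensor.cuda()  # placed on gpu_id, not on the default GPU 0
    return gpu_id, tensor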