Hi @MorvanZhou. Thank you for your tutorial. I'm trying to modify the A3C from Cartpole to MsPacman. I found that after I change the network to a CNN, the code will get stuck in the forward function. It could be run on Mac without problems. But It will get stuck when running on Linux. To illustrate the problem, I simply changed the N_S to 10000 in discrete_A3C.py and use a randomly generated numpy vector as a state. It will also stuck in forward function and has no any warning or error information. Do you have any ideas about that?
"""
Reinforcement Learning (A3C) using Pytroch + multiprocessing.
The most simple implementation for continuous action.
View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""
import torch
import numpy as np
import torch.nn as nn
from utils import v_wrap, set_init, push_and_pull, record
import torch.nn.functional as F
import torch.multiprocessing as mp
from shared_adam import SharedAdam
import gym
import os
os.environ["OMP_NUM_THREADS"] = "1"
UPDATE_GLOBAL_ITER = 10
GAMMA = 0.9
MAX_EP = 4000
env = gym.make('CartPole-v0')
N_S = 10000
N_A = env.action_space.n
class Net(nn.Module):
def __init__(self, s_dim, a_dim):
super(Net, self).__init__()
self.s_dim = s_dim
self.a_dim = a_dim
self.pi1 = nn.Linear(s_dim, 100)
self.pi2 = nn.Linear(100, a_dim)
self.v1 = nn.Linear(s_dim, 100)
self.v2 = nn.Linear(100, 1)
set_init([self.pi1, self.pi2, self.v1, self.v2])
self.distribution = torch.distributions.Categorical
def forward(self, x):
pi1 = F.relu(self.pi1(x))
logits = self.pi2(pi1)
v1 = F.relu(self.v1(x))
values = self.v2(v1)
return logits, values
def choose_action(self, s):
self.eval()
logits, _ = self.forward(s)
prob = F.softmax(logits, dim=1).data
m = self.distribution(prob)
return m.sample().numpy()[0]
def loss_func(self, s, a, v_t):
self.train()
logits, values = self.forward(s)
td = v_t - values
c_loss = td.pow(2)
probs = F.softmax(logits, dim=1)
m = self.distribution(probs)
exp_v = m.log_prob(a) * td.detach().squeeze()
a_loss = -exp_v
total_loss = (c_loss + a_loss).mean()
return total_loss
class Worker(mp.Process):
def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name):
super(Worker, self).__init__()
self.name = 'w%i' % name
self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
self.gnet, self.opt = gnet, opt
self.lnet = Net(N_S, N_A) # local network
self.env = gym.make('MsPacman-v0').unwrapped
def run(self):
total_step = 1
while self.g_ep.value < MAX_EP:
s = self.env.reset()
s = np.random.rand(N_S)
buffer_s, buffer_a, buffer_r = [], [], []
ep_r = 0.
while True:
if self.name == 'w0':
self.env.render()
a = self.lnet.choose_action(v_wrap(s[None, :]))
s_, r, done, _ = self.env.step(a)
s_ = np.random.rand(N_S)
if done: r = -1
ep_r += r
buffer_a.append(a)
buffer_s.append(s)
buffer_r.append(r)
if total_step % UPDATE_GLOBAL_ITER == 0 or done: # update global and assign to local net
# sync
push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
buffer_s, buffer_a, buffer_r = [], [], []
if done: # done and print information
record(self.g_ep, ep_r, self.res_queue, self.name, 1, 0)
break
s = s_
total_step += 1
self.res_queue.put(None)
if __name__ == "__main__":
gnet = Net(N_S, N_A) # global network
gnet.share_memory() # share the global parameters in multiprocessing
opt = SharedAdam(gnet.parameters(), lr=0.0001) # global optimizer
global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
# parallel training
workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i) for i in range(mp.cpu_count())]
[w.start() for w in workers]
res = [] # record episode reward to plot
while True:
r = res_queue.get()
if r is not None:
res.append(r)
else:
break
[w.join() for w in workers]
import matplotlib.pyplot as plt
plt.plot(res)
plt.ylabel('Moving average ep reward')
plt.xlabel('Step')
plt.show()
I encountered a similar problem long time ago. It will be stuck when using a large numpy array with multiprocessing. I haven't found any useful method to deal with this problem.
The problem could be solved by add mp.set_start_method("spawn") to the beginning of the if __name__ == '__main__' scope. The answer is referred from here.
Hi @MorvanZhou. Thank you for your tutorial. I'm trying to modify the A3C from Cartpole to MsPacman. I found that after I change the network to a CNN, the code will get stuck in the forward function. It could be run on Mac without problems. But It will get stuck when running on Linux. To illustrate the problem, I simply changed the N_S to 10000 in discrete_A3C.py and use a randomly generated numpy vector as a state. It will also stuck in forward function and has no any warning or error information. Do you have any ideas about that?