MariamDundua closed this issue 3 years ago.
I did not quite understand what you are trying to achieve. Could you provide the full code that throws the error?
Note that we do not offer much tech support for custom hacks/modifications (these issues are for bugs/enhancements). You could also take a look at stable-baselines3, which is built on PyTorch and is easier to modify.
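For reference, the A2C training shown further down would look roughly the same under stable-baselines3; a minimal sketch, assuming the same custom env and dataframe as in this thread:

```python
# Sketch only: stable-baselines3 equivalent of the A2C setup used in this thread.
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv

env_train = DummyVecEnv([lambda: StockEnvTrain(df)])
model = A2C("MlpPolicy", env_train, verbose=0)
model.learn(total_timesteps=25000)
```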
Here is my environment for training:
import gym
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from gym import spaces
from gym.utils import seeding

# Required imports; the model parameters (day, rho_*, ss_*, a1-a3, b1-b4, e1)
# are defined elsewhere in the notebook.

class StockEnvTrain(gym.Env):
    """A stock trading environment for OpenAI gym"""
    metadata = {'render.modes': ['human']}

    def __init__(self, df):
        #super(StockEnv, self).__init__()
        #money = 10 , scope = 1
        self.day = day  # `day` is expected to be defined globally
        self.df = df
        # action_space normalization and shape
        self.action_space = spaces.Box(low=-1, high=1, shape=(1,))
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=(23,))
        # load data from a pandas dataframe
        self.data = self.df.iloc[-1, :]
        self.dataa = self.df.iloc[-4, :]
        self.terminal = False
        self.count = 0
        self.D4L_CPI_TAR = 3
        # initialize state
        self.state = [self.data.RS.tolist()] + \
                     [self.data.L_GDP_RW_GAP.tolist()] + \
                     [self.data.DLA_CPI_RW.tolist()] + \
                     [self.data.L_CPI_RW.tolist()] + \
                     [self.data.RR_RW_BAR.tolist()] + \
                     [self.data.RS_RW.tolist()] + \
                     [self.data.RR.tolist()] + \
                     [self.data.RR_BAR.tolist()] + \
                     [self.data.RR_GAP.tolist()] + \
                     [self.data.DLA_Z_BAR.tolist()] + \
                     [self.data.L_Z_BAR.tolist()] + \
                     [self.data.L_Z_GAP.tolist()] + \
                     [self.data.PREM.tolist()] + \
                     [self.data.L_GDP_GAP.tolist()] + \
                     [self.data.DLA_CPI.tolist()] + \
                     [self.dataa.L_CPI.tolist()] + \
                     [self.data.D4L_CPI.tolist()] + \
                     [self.data.L_S.tolist()] + \
                     [self.data.L_Z.tolist()] + \
                     [self.data.DLA_GDP_BAR.tolist()] + \
                     [self.data.L_GDP_BAR.tolist()] + \
                     [self.data.L_GDP.tolist()] + \
                     [self.data.sum_deviance.tolist()]
        # initialize reward
        self.reward = 0
        #self.cost = 0
        # memorize all the total balance change
        self.asset_memory = self.data.sum_deviance.tolist()
        self.rewards_memory = []
        #self.trades = 0
        #self.reset()
        self._seed()
    def _sell_stock(self, action):
        self.state[0] = self.df.iloc[-1].RS + action
        self.state[1] = rho_L_GDP_RW_GAP*self.df.iloc[-1].L_GDP_RW_GAP
        self.state[2] = rho_DLA_CPI_RW*self.df.iloc[-1].DLA_CPI_RW + (1-rho_DLA_CPI_RW)*ss_DLA_CPI_RW
        self.state[3] = 0.25*self.state[2] + self.df.iloc[-1].DLA_CPI_RW
        self.state[4] = rho_RR_RW_BAR*self.df.iloc[-1].RR_RW_BAR + (1-rho_RR_RW_BAR)*ss_RR_RW_BAR
        self.state[5] = rho_RS_RW*self.df.iloc[-1].RS_RW + (1-rho_RS_RW)*(self.state[4] + self.state[2])
        self.state[6] = (self.df.iloc[-1].RS+action) - _states[16]  # 1 D4L_CPI
        self.state[7] = rho_RR_BAR*self.df.iloc[-1].RR_BAR + (1-rho_RR_BAR)*ss_RR_BAR
        self.state[8] = -_states[16] - self.df.iloc[-1].RR_BAR*rho_RR_BAR + (self.df.iloc[-1].RS+action) + rho_RR_BAR*ss_RR_BAR - ss_RR_BAR
        self.state[9] = rho_DLA_Z_BAR*self.df.iloc[-1].DLA_Z_BAR + (1-rho_DLA_Z_BAR)*ss_DLA_Z_BAR
        self.state[10] = 0.25*self.df.iloc[-1].DLA_Z_BAR*rho_DLA_Z_BAR + self.df.iloc[-1].L_Z_BAR - 0.25*rho_DLA_Z_BAR*ss_DLA_Z_BAR + 0.25*ss_DLA_Z_BAR
        self.state[11] = (-2.0*self.D4L_CPI_TAR*e1 - _states[14]*a1 + _states[14] + self.df.iloc[-1].DLA_CPI*a1 + _states[9] - 4.0*self.state[3] + 4.0* self.df.iloc[-1].L_CPI + 4.0*_states[17]*e1 - 4.0*_states[17] - 4.0*self.df.iloc[-1].L_S*e1 + 4.0*self.df.iloc[-1].L_Z_BAR + self.state[4] + (self.df.iloc[-1].RS+action) - self.state[5] + a2*a3*b2*b4*(_states[16] - (self.df.iloc[-1].RS+action)) + a2*a3*(_states[13]*b1 + self.state[1]*b3) + 2.0*e1*ss_DLA_CPI_RW - (2.0*e1 - 1.0)*(self.df.iloc[-1].DLA_Z_BAR*rho_DLA_Z_BAR - rho_DLA_Z_BAR*ss_DLA_Z_BAR + ss_DLA_Z_BAR) + (a2*a3*b2*b4 - 1.0)*(self.df.iloc[-1].RR_BAR*rho_RR_BAR - rho_RR_BAR*ss_RR_BAR + ss_RR_BAR))/(a2*a3*b2*(b4 - 1.0) + a2*a3 - a2 - 4.0)
        self.state[12] = -_states[9] + self.df.iloc[-1].RR_BAR*rho_RR_BAR - self.state[4] - rho_RR_BAR*ss_RR_BAR + ss_RR_BAR
        self.state[13] = (4.0*self.state[3]*b2*(b4 - 1.0) - 4.0* self.df.iloc[-1].L_CPI*b2*(b4 - 1.0) - 4.0*self.df.iloc[-1].L_Z_BAR*b2*(b4 - 1.0) - b2*b4*(_states[16] - (self.df.iloc[-1].RS+action))*(-a2*a3 + a2 + 4.0) - b2*(_states[9] + self.state[4])*(b4 - 1.0) - b2*(b4 - 1.0)*(-_states[14]*a1 + _states[14]+ _states[14]*a1) + b2*(b4 - 1.0)*(2.0*self.D4L_CPI_TAR*e1 - 4.0*_states[17]*e1 + 4.0*_states[17] + 4.0*self.df.iloc[-1].L_S*e1 - (self.df.iloc[-1].RS+action) + self.state[5] - 2.0*e1*ss_DLA_CPI_RW) + b2*(-b4 + 2.0*e1*(b4 - 1.0) + 1.0)*(self.df.iloc[-1].DLA_Z_BAR*rho_DLA_Z_BAR - rho_DLA_Z_BAR*ss_DLA_Z_BAR + ss_DLA_Z_BAR) - b2*(self.df.iloc[-1].RR_BAR*rho_RR_BAR - rho_RR_BAR*ss_RR_BAR + ss_RR_BAR)*(-a2*a3*b4 + a2*b4 + 3.0*b4 + 1.0) - (_states[13]*b1 + self.state[1]*b3)*(-a2*a3 + a2 + 4.0))/(a2*a3*b2*(b4 - 1.0) + a2*a3 - a2 - 4.0)
        self.state[14] = (4.0*_states[14]*a1 - 4.0*_states[14] - 4.0*self.df.iloc[-1].DLA_CPI*a1 + 4.0*self.state[3]*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0) - 4.0* self.df.iloc[-1].L_CPI*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0) - 4.0*self.df.iloc[-1].L_Z_BAR*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0) - 4.0*a2*a3*b2*b4*(_states[16] - (self.df.iloc[-1].RS+action)) - 4.0*a2*a3*(_states[13]*b1 + self.state[1]*b3) - a2*(_states[9] + self.state[4])*(a3*b2*(b4 - 1.0) + a3 - 1.0) + a2*(self.df.iloc[-1].DLA_Z_BAR*rho_DLA_Z_BAR - rho_DLA_Z_BAR*ss_DLA_Z_BAR + ss_DLA_Z_BAR)*(2.0*a3*b2*e1*(b4 - 1.0) - a3*b2*(b4 - 1.0) - a3 + 2.0*e1*(a3 - 1.0) + 1.0) - a2*(self.df.iloc[-1].RR_BAR*rho_RR_BAR - rho_RR_BAR*ss_RR_BAR + ss_RR_BAR)*(4.0*a3*b2*b4 - a3*b2*(b4 - 1.0) - a3 + 1.0) + a2*(a3*b2*(b4 - 1.0) + a3 - 1.0)*(2.0*self.D4L_CPI_TAR*e1 - 4.0*L_S_1*e1 + 4.0*L_S_1 + 4.0*L_S__1*e1 - (self.df.iloc[-1].RS+action) + self.state[5] - 2.0*e1*ss_DLA_CPI_RW))/(a2*a3*b2*(b4 - 1.0) + a2*a3 - a2 - 4.0)
        self.state[15] = (_states[14]*a1 - _states[14] - self.df.iloc[-1].DLA_CPI*a1 + self.state[3]*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0) - 4.0* self.df.iloc[-1].L_CPI - self.df.iloc[-1].L_Z_BAR*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0) - a2*a3*b2*b4*(_states[16] - (self.df.iloc[-1].RS+action)) - a2*a3*(_states[13]*b1 + self.state[1]*b3) - 0.25*a2*(_states[9] + self.state[4])*(a3*b2*(b4 - 1.0) + a3 - 1.0) + 0.25*a2*(self.df.iloc[-1].DLA_Z_BAR*rho_DLA_Z_BAR - rho_DLA_Z_BAR*ss_DLA_Z_BAR + ss_DLA_Z_BAR)*(2.0*a3*b2*e1*(b4 - 1.0) - a3*b2*(b4 - 1.0) - a3 + 2.0*e1*(a3 - 1.0) + 1.0) - 0.25*a2*(self.df.iloc[-1].RR_BAR*rho_RR_BAR - rho_RR_BAR*ss_RR_BAR + ss_RR_BAR)*(4.0*a3*b2*b4 - a3*b2*(b4 - 1.0) - a3 + 1.0) + 0.25*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0)*(2.0*self.D4L_CPI_TAR*e1 - 4.0*_states[17]*e1 + 4.0*_states[17] + 4.0*self.df.iloc[-1].L_S*e1 - (self.df.iloc[-1].RS+action) + self.state[5] - 2.0*e1*ss_DLA_CPI_RW))/(a2*a3*b2*(b4 - 1.0) + a2*a3 - a2 - 4.0)
        self.state[16] = (_states[14]*a1 - _states[14] - self.df.iloc[-1].DLA_CPI*a1 + self.state[3]*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0) - 4.0* self.df.iloc[-1].L_CPI - self.df.iloc[-4].L_CPI*(a2*a3*b2*(b4 - 1.0) + a2*a3 - a2 - 4.0) - self.df.iloc[-1].L_Z_BAR*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0) - a2*a3*b2*b4*(_states[16] - (self.df.iloc[-1].RS+action)) - a2*a3*(_states[13]*b1 + self.state[1]*b3) - 0.25*a2*(self.df.iloc[-1].DLA_Z_BAR + self.state[4])*(a3*b2*(b4 - 1.0) + a3 - 1.0) + 0.25*a2*(self.df.iloc[-1].DLA_Z_BAR*rho_DLA_Z_BAR - rho_DLA_Z_BAR*ss_DLA_Z_BAR + ss_DLA_Z_BAR)*(2.0*a3*b2*e1*(b4 - 1.0) - a3*b2*(b4 - 1.0) - a3 + 2.0*e1*(a3 - 1.0) + 1.0) - 0.25*a2*(self.df.iloc[-1].RR_BAR*rho_RR_BAR - rho_RR_BAR*ss_RR_BAR + ss_RR_BAR)*(4.0*a3*b2*b4 - a3*b2*(b4 - 1.0) - a3 + 1.0) + 0.25*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0)*(2.0*self.D4L_CPI_TAR*e1 - 4.0*_states[17]*e1 + 4.0*_states[17] + 4.0*self.df.iloc[-1].L_S*e1 - (self.df.iloc[-1].RS+action) + self.state[5] - 2.0*e1*ss_DLA_CPI_RW))/(a2*a3*b2*(b4 - 1.0) + a2*a3 - a2 - 4.0)
        self.state[17] = 0.5*self.D4L_CPI_TAR*e1 - 0.25*_states[9] + 0.5*self.df.iloc[-1].DLA_Z_BAR*e1*rho_DLA_Z_BAR - _states[17]*e1 + _states[17] + self.df.iloc[-1].L_S*e1 + 0.25*self.df.iloc[-1].RR_BAR*rho_RR_BAR - 0.25*self.state[4] - 0.25*(self.df.iloc[-1].RS+action) + 0.25*self.state[5] - 0.5*e1*rho_DLA_Z_BAR*ss_DLA_Z_BAR - 0.5*e1*ss_DLA_CPI_RW + 0.5*e1*ss_DLA_Z_BAR - 0.25*rho_RR_BAR*ss_RR_BAR + 0.25*ss_RR_BAR
        self.state[18] = (-2.0*self.D4L_CPI_TAR*e1 - _states[14]*a1 + _states[14] + self.df.iloc[-1].DLA_CPI*a1 + _states[9] - 4.0*self.state[3] + 4.0* self.df.iloc[-1].L_CPI + 4.0*_states[17]*e1 - 4.0*_states[17] - 4.0*self.df.iloc[-1].L_S*e1 + self.df.iloc[-1].L_Z_BAR*a2*(a3*b2*(b4 - 1.0) + a3 - 1.0) + self.state[4] + (self.df.iloc[-1].RS+action) - self.state[5] + a2*a3*b2*b4*(_states[16] - (self.df.iloc[-1].RS+action)) + a2*a3*(_states[13]*b1 + self.state[1]*b3) + 2.0*e1*ss_DLA_CPI_RW + (a2*a3*b2*b4 - 1.0)*(self.df.iloc[-1].RR_BAR*rho_RR_BAR - rho_RR_BAR*ss_RR_BAR + ss_RR_BAR) + 0.25*(self.df.iloc[-1].DLA_Z_BAR*rho_DLA_Z_BAR - rho_DLA_Z_BAR*ss_DLA_Z_BAR + ss_DLA_Z_BAR)*(a2*a3*b2*(b4 - 1.0) + a2*a3 - a2 - 8.0*e1))/(a2*a3*b2*(b4 - 1.0) + a2*a3 - a2 - 4.0)
        self.state[19] = rho_DLA_GDP_BAR*self.df.iloc[-1].DLA_GDP_BAR + (1-rho_DLA_GDP_BAR)*ss_DLA_GDP_BAR
        self.state[20] = 0.25*self.state[19] + self.df.iloc[-1].L_GDP_BAR
        self.state[21] = self.state[13] + self.state[20]
        self.state[22] = self.data.sum_deviance.values  # could be changed
        self.df = self.df.append({'RS':self.state[0],'L_GDP_RW_GAP':self.state[1],'DLA_CPI_RW':self.state[2],'L_CPI_RW':self.state[3],'RR_RW_BAR':self.state[4],'RS_RW':self.state[5],'RR':self.state[6],'RR_BAR':self.state[7],'RR_GAP':self.state[8],'DLA_Z_BAR':self.state[9],'L_Z_BAR':self.state[10],'L_Z_GAP':self.state[11],'PREM':self.state[12],'L_GDP_GAP':self.state[13],'DLA_CPI':self.state[14],'L_CPI':self.state[15],'D4L_CPI':self.state[16],'L_S':self.state[17],'L_Z':self.state[18],'DLA_GDP_BAR':self.state[19],'L_GDP_BAR':self.state[20],'L_GDP':self.state[21],'sum_deviance':self.state[22]}, ignore_index=True)
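        # NOTE: `_states` used in several of the expressions above is never assigned inside
        # the environment; it comes from model.predict() at notebook scope, which returns
        # None for the feed-forward MlpPolicy, hence the TypeError in the traceback below.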
    def step(self, action):
        # print(self.day)
        self.terminal = self.count > 1000
        # print(actions)

        if self.terminal:
            plt.plot(self.asset_memory, 'r')
            plt.savefig('/content/account_value_train.png')
            plt.close()
            end_total_asset = self.state[22]
            #print("end_total_asset:{}".format(end_total_asset))
            df_total_value = pd.DataFrame(self.asset_memory)
            df_total_value.to_csv('/content/account_value_train.csv')
            #print("total_reward:{}".format(self.state[0]+sum(np.array(self.state[1:(STOCK_DIM+1)])*np.array(self.state[(STOCK_DIM+1):61]))- INITIAL_ACCOUNT_BALANCE ))
            #print("total_cost: ", self.cost)
            #print("total_trades: ", self.trades)
            df_total_value.columns = ['account_value']
            #df_total_value['daily_return']=df_total_value.pct_change(1)
            #sharpe = (252**0.5)*df_total_value['daily_return'].mean()/ \
            #         df_total_value['daily_return'].std()
            #print("Sharpe: ",sharpe)
            #print("=================================")
            df_rewards = pd.DataFrame(self.rewards_memory)
            #df_rewards.to_csv('results/account_rewards_train.csv')
            # print('total asset: {}'.format(self.state[0]+ sum(np.array(self.state[1:29])*np.array(self.state[29:]))))
            #with open('obs.pkl', 'wb') as f:
            #    pickle.dump(self.state, f)
            return self.state, self.reward, self.terminal, {}
        else:
            # print(np.array(self.state[1:29]))
            #action = action[0][0]
            #_states=_states[0]
            #actions = (actions.astype(int))
            begin_total_asset = np.var(df.D4L_CPI.iloc[43:]) + np.var(df.L_GDP.iloc[43:])
            #print("begin_total_asset:{}".format(begin_total_asset))
            #argsort_actions = np.argsort(actions)
            #sell_index = argsort_actions[:np.where(actions < 0)[0].shape[0]]
            #buy_index = argsort_actions[::-1][:np.where(actions > 0)[0].shape[0]]
            self._sell_stock(action)
            self.count += 1
            #self.day =-1
            self.data = self.df.iloc[-1, :]
            self.dataa = self.df.iloc[-4, :]
            # load next state
            # print("stock_shares:{}".format(self.state[29:]))
            self.state = [self.data.RS.tolist()] + \
                         [self.data.L_GDP_RW_GAP.tolist()] + \
                         [self.data.DLA_CPI_RW.tolist()] + \
                         [self.data.L_CPI_RW.tolist()] + \
                         [self.data.RR_RW_BAR.tolist()] + \
                         [self.data.RS_RW.tolist()] + \
                         [self.data.RR.tolist()] + \
                         [self.data.RR_BAR.tolist()] + \
                         [self.data.RR_GAP.tolist()] + \
                         [self.data.DLA_Z_BAR.tolist()] + \
                         [self.data.L_Z_BAR.tolist()] + \
                         [self.data.L_Z_GAP.tolist()] + \
                         [self.data.PREM.tolist()] + \
                         [self.data.L_GDP_GAP.tolist()] + \
                         [self.data.DLA_CPI.tolist()] + \
                         [self.dataa.L_CPI.tolist()] + \
                         [self.data.D4L_CPI.tolist()] + \
                         [self.data.L_S.tolist()] + \
                         [self.data.L_Z.tolist()] + \
                         [self.data.DLA_GDP_BAR.tolist()] + \
                         [self.data.L_GDP_BAR.tolist()] + \
                         [self.data.L_GDP.tolist()] + \
                         [self.data.sum_deviance.tolist()]
            end_total_asset = np.var(df.D4L_CPI.iloc[43:]) + np.var(df.L_GDP.iloc[43:])
            self.asset_memory.append(end_total_asset)
            #print("end_total_asset:{}".format(end_total_asset))
            self.reward = -1*(end_total_asset - begin_total_asset)
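            # NOTE: `df` in begin_total_asset / end_total_asset above is the module-level
            # dataframe, not self.df, and self.df.append() returns a new frame, so both
            # variances appear to be computed over the same data and this reward stays 0.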
# print("step_reward:{}".format(self.reward))
self.rewards_memory.append(self.reward)
self.reward = self.reward
return self.state, self.reward, self.terminal, {}
    def reset(self):
        self.asset_memory = self.data.sum_deviance.tolist()
        #self.day = -1
        self.data = self.df.iloc[-1, :]
        self.dataa = self.df.iloc[-4, :]
        self.cost = 0
        self.trades = 0
        self.terminal = False
        self.rewards_memory = []
        self.count = 0
        # initiate state
        self.state = [self.data.RS.tolist()] + \
                     [self.data.L_GDP_RW_GAP.tolist()] + \
                     [self.data.DLA_CPI_RW.tolist()] + \
                     [self.data.L_CPI_RW.tolist()] + \
                     [self.data.RR_RW_BAR.tolist()] + \
                     [self.data.RS_RW.tolist()] + \
                     [self.data.RR.tolist()] + \
                     [self.data.RR_BAR.tolist()] + \
                     [self.data.RR_GAP.tolist()] + \
                     [self.data.DLA_Z_BAR.tolist()] + \
                     [self.data.L_Z_BAR.tolist()] + \
                     [self.data.L_Z_GAP.tolist()] + \
                     [self.data.PREM.tolist()] + \
                     [self.data.L_GDP_GAP.tolist()] + \
                     [self.data.DLA_CPI.tolist()] + \
                     [self.dataa.L_CPI.tolist()] + \
                     [self.data.D4L_CPI.tolist()] + \
                     [self.data.L_S.tolist()] + \
                     [self.data.L_Z.tolist()] + \
                     [self.data.DLA_GDP_BAR.tolist()] + \
                     [self.data.L_GDP_BAR.tolist()] + \
                     [self.data.L_GDP.tolist()] + \
                     [self.data.sum_deviance.tolist()]
        # iteration += 1
        return self.state

    def render(self, mode='human'):
        return self.state

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]
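Recent versions of stable-baselines ship an environment checker that usually flags problems like observations falling outside the declared observation_space or a step() that raises; a small sketch of running it on the raw environment (before wrapping it in DummyVecEnv), assuming the checker is available in the installed version:

```python
# Sketch: sanity-check the custom env before training.
from stable_baselines.common.env_checker import check_env

check_env(StockEnvTrain(df), warn=True)
```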
After that, I try to train the model:
env_train = DummyVecEnv([lambda: StockEnvTrain(df)])
obs_train = env_train.reset()
model = A2C('MlpPolicy', env_train, verbose=0)
model.learn(total_timesteps=25000)

start = time.time()
#model = PPO2('MlpPolicy', env_train, ent_coef = 0.005, nminibatches = 8)
#model.learn(total_timesteps=200)
end = time.time()

# env_trade / obs_trade come from a separate trading environment (not shown here)
for i in range(1000):
    action, _states = model.predict(obs_trade)
    obs_trade, rewards, dones, info = env_trade.step(action)
    if i == 10000:  # note: never reached inside range(1000)
        last_state = env_trade.render()
It gives me this error:
/usr/local/lib/python3.7/dist-packages/stable_baselines/a2c/a2c.py in learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
261 callback.on_rollout_start()
262 # true_reward is the reward without discount
--> 263 rollout = self.runner.run(callback)
264 # unpack
265 obs, states, rewards, masks, actions, values, ep_infos, true_reward = rollout
/usr/local/lib/python3.7/dist-packages/stable_baselines/common/runners.py in run(self, callback)
46 self.callback = callback
47 self.continue_training = True
---> 48 return self._run()
49
50 @abstractmethod
/usr/local/lib/python3.7/dist-packages/stable_baselines/a2c/a2c.py in _run(self)
359 if isinstance(self.env.action_space, gym.spaces.Box):
360 clipped_actions = np.clip(actions, self.env.action_space.low, self.env.action_space.high)
--> 361 obs, rewards, dones, infos = self.env.step(clipped_actions)
362
363 self.model.num_timesteps += self.n_envs
/usr/local/lib/python3.7/dist-packages/stable_baselines/common/vec_env/base_vec_env.py in step(self, actions)
148 """
149 self.step_async(actions)
--> 150 return self.step_wait()
151
152 def get_images(self) -> Sequence[np.ndarray]:
/usr/local/lib/python3.7/dist-packages/stable_baselines/common/vec_env/dummy_vec_env.py in step_wait(self)
42 for env_idx in range(self.num_envs):
43 obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\
---> 44 self.envs[env_idx].step(self.actions[env_idx])
45 if self.buf_dones[env_idx]:
46 # save final observation where user can get it, then reset
<ipython-input-90-faf7c1a0a093> in step(self, action)
240
241
--> 242 self._sell_stock(action)
243
244
<ipython-input-90-faf7c1a0a093> in _sell_stock(self, action)
169 self.state[5]= rho_RS_RW*self.df.iloc[-1].RS_RW + (1-rho_RS_RW)*(self.state[4] + self.state[2])
170
--> 171 self.state[6] = (self.df.iloc[-1].RS+action) - _states[16] # 1 D4L_CPI
172 self.state[7] = rho_RR_BAR*self.df.iloc[-1].RR_BAR + (1-rho_RR_BAR)*ss_RR_BAR
173 self.state[8]= -_states[16] - self.df.iloc[-1].RR_BAR*rho_RR_BAR + (self.df.iloc[-1].RS+action) + rho_RR_BAR*ss_RR_BAR - ss_RR_BAR
TypeError: 'NoneType' object is not subscriptable
From this error I realized that _states is NoneType; I wanted to use _states[16] as a prediction of a future state.
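For context: in stable-baselines, model.predict() returns a tuple (action, states), and states is only populated for recurrent policies such as MlpLstmPolicy, where it is fed back into the next predict() call together with the done mask. With the feed-forward MlpPolicy used above it is always None, so any indexing such as _states[16] raises exactly this TypeError. A minimal illustration (variable names mirror the code above; the recurrent variant is an assumption, not the author's setup):

```python
# MlpPolicy (used above) is feed-forward: predict() returns no hidden state.
action, _states = model.predict(obs_train)
print(_states)  # -> None, so _states[16] fails with TypeError

# Only a recurrent policy returns and consumes a state between calls (sketch):
# model = A2C('MlpLstmPolicy', env_train, verbose=0)
# action, _states = model.predict(obs_train, state=_states, mask=dones)
```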
That does not seem to have anything to do with stable-baselines; it looks like a bug in the environment. Please only post issues here if you spot a bug in the stable-baselines library itself.
Closing as "no tech support".
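For anyone hitting the same thing: the environment cannot rely on model.predict() being called on its behalf during learn(), so an "expected future state" has to live inside the environment itself. A rough sketch of that idea with purely illustrative names (not the author's code):

```python
import gym
import numpy as np
from gym import spaces

class EnvWithInternalForecast(gym.Env):
    """Sketch: the env keeps its own estimate of next period's value
    instead of reading the hidden state returned by model.predict()."""

    def __init__(self):
        self.action_space = spaces.Box(low=-1, high=1, shape=(1,))
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,))
        self.prev_obs = np.zeros(2, dtype=np.float32)

    def step(self, action):
        # use the previously computed observation as the "forecast" of the future state
        forecast = self.prev_obs[1]
        obs = np.array([float(action[0]) - forecast,
                        0.5 * forecast + float(action[0])], dtype=np.float32)
        self.prev_obs = obs
        return obs, 0.0, False, {}

    def reset(self):
        self.prev_obs = np.zeros(2, dtype=np.float32)
        return self.prev_obs
```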
I am using the stable-baselines library for my RL model. When constructing the state of the environment, in one place I need a prediction of the future state, so I decided to use the model's expected next state. An example is below.
In order to construct the env, I need the _states information from model.predict(obs). My problem is putting that _states information into my current state. When I set _states in my state, training the model gives me the error: TypeError: 'NoneType' object is not subscriptable.
I wrote _states[0] in my state because I need a prediction of the first state. From my error message I realize that _states is 'NoneType'. Where am I wrong? Here is a small version of my state. Initial state:
df is a pandas dataframe. RS and L_GDP_RW_GAP are columns of df. When I want to update the state I am using: