AkashKarnatak closed this 1 year ago
Have you tested it? Does it work?
Yes, here are the results. Before the fix, the second reset returns the last day's prices:
>>> env = StockTradingEnv(df = train, **env_kwargs)
>>> print('Day 0:', train.loc[0].close.values)
Day 0: [ 19.94001961 61.10700989 222.36936951 79.38145447 97.45033264
144.63755798 15.24585056 5.74729586 303.35714722 33.54574585
105.43174744 178.36380005 292.96624756 415.73022461 19.91226387
166.92402649 22.62735176 251.48953247 91.10407257 493.23080444
49.65197372 189.57377625 71.15727234 37.47664261 102.81951904
41.22008896 18.54531288 96.87306976 300.88070679 64.68822479
489.25418091 443.50683594 89.87454987 62.28596878 40.28764725
287.0161438 111.89730072 94.98454285 50.7267189 32.15766907
7.10770512 94.18431854 54.51696396 43.74747086 364.16394043
65.57043457 47.51332855]
>>> last_day = len(train.date.unique())
>>> print(f'Day {last_day}:', train.loc[last_day-1].close.values)
Day 2577: [ 148.79171753 399.91134644 1342.78930664 1320.36901855
806.31970215 2465.26635742 851.19567871 3647.03857422
343.90393066 299.88269043 2462.59765625 543.18145752
1550.21484375 2458.41845703 1718.94885254 894.87109375
472.84255981 2033.78674316 1186.35107422 2243.98901367
201.33979797 1682.8092041 431.11804199 1370.0057373
652.5690918 231.3109436 241.91317749 1474.68359375
1435.78051758 629.96569824 6305.32128906 11118.05566406
110.75399017 129.90248108 123.15526581 1222.25524902
345.72662354 383.10574341 261.18582153 162.55000305
20.83191109 2044.33349609 620.47149658 1304.79138184
4482.19433594 601.64862061 275.31658936]
>>> s, _ = env.reset()
>>> print('First reset', s[1: 1 + stock_dimension])
First reset [ 19.94001961 61.10700989 222.36936951 79.38145447 97.45033264
144.63755798 15.24585056 5.74729586 303.35714722 33.54574585
105.43174744 178.36380005 292.96624756 415.73022461 19.91226387
166.92402649 22.62735176 251.48953247 91.10407257 493.23080444
49.65197372 189.57377625 71.15727234 37.47664261 102.81951904
41.22008896 18.54531288 96.87306976 300.88070679 64.68822479
489.25418091 443.50683594 89.87454987 62.28596878 40.28764725
287.0161438 111.89730072 94.98454285 50.7267189 32.15766907
7.10770512 94.18431854 54.51696396 43.74747086 364.16394043
65.57043457 47.51332855]
>>> while True:
...     s_, r, d, t, _ = env.step(env.action_space.sample())
...     s = s_
...     if d or t: break
...
>>> s, _ = env.reset()
>>> print('Second reset', s[1: 1 + stock_dimension])
Second reset [ 148.79171753 399.91134644 1342.78930664 1320.36901855
806.31970215 2465.26635742 851.19567871 3647.03857422
343.90393066 299.88269043 2462.59765625 543.18145752
1550.21484375 2458.41845703 1718.94885254 894.87109375
472.84255981 2033.78674316 1186.35107422 2243.98901367
201.33979797 1682.8092041 431.11804199 1370.0057373
652.5690918 231.3109436 241.91317749 1474.68359375
1435.78051758 629.96569824 6305.32128906 11118.05566406
110.75399017 129.90248108 123.15526581 1222.25524902
345.72662354 383.10574341 261.18582153 162.55000305
20.83191109 2044.33349609 620.47149658 1304.79138184
4482.19433594 601.64862061 275.31658936]
After the fix, the second reset correctly returns the first day's prices:
>>> env = StockTradingEnv(df = train, **env_kwargs)
>>> print('Day 0:', train.loc[0].close.values)
Day 0: [ 19.94001961 61.10700989 222.36936951 79.38145447 97.45033264
144.63755798 15.24585056 5.74729586 303.35714722 33.54574585
105.43174744 178.36380005 292.96624756 415.73022461 19.91226387
166.92402649 22.62735176 251.48953247 91.10407257 493.23080444
49.65197372 189.57377625 71.15727234 37.47664261 102.81951904
41.22008896 18.54531288 96.87306976 300.88070679 64.68822479
489.25418091 443.50683594 89.87454987 62.28596878 40.28764725
287.0161438 111.89730072 94.98454285 50.7267189 32.15766907
7.10770512 94.18431854 54.51696396 43.74747086 364.16394043
65.57043457 47.51332855]
>>> last_day = len(train.date.unique())
>>> print(f'Day {last_day}:', train.loc[last_day-1].close.values)
Day 2577: [ 148.79171753 399.91134644 1342.78930664 1320.36901855
806.31970215 2465.26635742 851.19567871 3647.03857422
343.90393066 299.88269043 2462.59765625 543.18145752
1550.21484375 2458.41845703 1718.94885254 894.87109375
472.84255981 2033.78674316 1186.35107422 2243.98901367
201.33979797 1682.8092041 431.11804199 1370.0057373
652.5690918 231.3109436 241.91317749 1474.68359375
1435.78051758 629.96569824 6305.32128906 11118.05566406
110.75399017 129.90248108 123.15526581 1222.25524902
345.72662354 383.10574341 261.18582153 162.55000305
20.83191109 2044.33349609 620.47149658 1304.79138184
4482.19433594 601.64862061 275.31658936]
>>> s, _ = env.reset()
>>> print('First reset', s[1: 1 + stock_dimension])
First reset [ 19.94001961 61.10700989 222.36936951 79.38145447 97.45033264
144.63755798 15.24585056 5.74729586 303.35714722 33.54574585
105.43174744 178.36380005 292.96624756 415.73022461 19.91226387
166.92402649 22.62735176 251.48953247 91.10407257 493.23080444
49.65197372 189.57377625 71.15727234 37.47664261 102.81951904
41.22008896 18.54531288 96.87306976 300.88070679 64.68822479
489.25418091 443.50683594 89.87454987 62.28596878 40.28764725
287.0161438 111.89730072 94.98454285 50.7267189 32.15766907
7.10770512 94.18431854 54.51696396 43.74747086 364.16394043
65.57043457 47.51332855]
>>> while True:
...     s_, r, d, t, _ = env.step(env.action_space.sample())
...     s = s_
...     if d or t: break
...
>>> s, _ = env.reset()
>>> print('Second reset', s[1: 1 + stock_dimension])
Second reset [ 19.94001961 61.10700989 222.36936951 79.38145447 97.45033264
144.63755798 15.24585056 5.74729586 303.35714722 33.54574585
105.43174744 178.36380005 292.96624756 415.73022461 19.91226387
166.92402649 22.62735176 251.48953247 91.10407257 493.23080444
49.65197372 189.57377625 71.15727234 37.47664261 102.81951904
41.22008896 18.54531288 96.87306976 300.88070679 64.68822479
489.25418091 443.50683594 89.87454987 62.28596878 40.28764725
287.0161438 111.89730072 94.98454285 50.7267189 32.15766907
7.10770512 94.18431854 54.51696396 43.74747086 364.16394043
65.57043457 47.51332855]
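The manual check in the sessions above can be turned into a reusable regression check. This is a minimal sketch, assuming a Gymnasium-style API (`reset()` returns `(obs, info)`, `step()` returns a 5-tuple), as used in the session, and that `obs[1:1 + stock_dimension]` holds the close prices. The helper name `reset_returns_first_day` and its `max_steps` guard are hypothetical.

```python
import math


def reset_returns_first_day(env, first_day_close, stock_dimension, max_steps=100_000):
    """Hypothetical helper: play one random episode, reset, and check the prices.

    Returns True if the price slice of the observation after the second reset
    matches the first day's close prices.
    """
    # Start a fresh episode.
    env.reset()
    # Take random actions until the episode terminates or is truncated.
    for _ in range(max_steps):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            break
    # The reset under test: prices should be back on the first day.
    obs, _ = env.reset()
    prices = list(obs[1:1 + stock_dimension])
    expected = list(first_day_close)
    # Loose tolerance, in case the observation is stored at lower precision.
    return len(prices) == len(expected) and all(
        math.isclose(p, e, rel_tol=1e-5, abs_tol=1e-8) for p, e in zip(prices, expected)
    )
```

With the real environment it would be called as `assert reset_returns_first_day(env, train.loc[0].close.values, stock_dimension)`. The tolerance is deliberately loose because the observation may hold the prices at a different floating-point precision than the DataFrame column.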
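Currently, StockTradingEnv.reset() initializes the state first and only then sets self.day to 0. As a result, every reset after the first one builds the state from the last day of the previous episode instead of the first day. The first reset is unaffected because self.day is still 0 from construction, which is why "First reset" matches Day 0 in both runs above. This commit fixes the issue by setting self.day to 0 before the state is initialized.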
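The ordering problem can be reproduced without FinRL. The sketch below uses a hypothetical `ToyTradingEnv` that only mimics the relevant parts of `StockTradingEnv`: a `day` counter, an `_initiate_state` helper, and Gymnasium-style `reset`/`step`. Those names, the `fixed` flag, and the made-up prices are assumptions for illustration, not the library's actual code. Setting `fixed=False` reproduces the original ordering, and `fixed=True` reproduces the corrected one.

```python
class ToyTradingEnv:
    """Hypothetical, heavily simplified stand-in for StockTradingEnv (illustration only).

    closes[d] holds the close prices of every stock on day d. The state is
    [cash] + closes[self.day], so state[1:] plays the role of s[1:1 + stock_dimension].
    """

    def __init__(self, closes, initial_amount=1_000_000.0, fixed=True):
        self.closes = closes
        self.initial_amount = initial_amount
        self.fixed = fixed  # True: corrected ordering, False: original ordering
        self.day = 0
        self.state = self._initiate_state()

    def _initiate_state(self):
        # Builds the state from whatever day the env currently points at.
        return [self.initial_amount] + list(self.closes[self.day])

    def reset(self):
        if self.fixed:
            self.day = 0                         # rewind to the first day first...
            self.state = self._initiate_state()  # ...so the state uses day-0 prices
        else:
            self.state = self._initiate_state()  # bug: self.day still points at the last day
            self.day = 0                         # rewinding here is too late
        return list(self.state), {}

    def step(self, action=None):
        # Advance one day; the episode terminates on the last day of data.
        self.day += 1
        self.state = [self.state[0]] + list(self.closes[self.day])
        terminated = self.day >= len(self.closes) - 1
        return list(self.state), 0.0, terminated, False, {}


closes = [[10.0, 20.0], [11.0, 21.0], [12.0, 22.0]]  # 3 days, 2 stocks (made-up numbers)
for fixed in (False, True):
    env = ToyTradingEnv(closes, fixed=fixed)
    env.reset()
    while not env.step()[2]:  # index 2 is `terminated`
        pass
    s, _ = env.reset()
    print("fixed" if fixed else "buggy", s[1:])
# buggy [12.0, 22.0]
# fixed [10.0, 20.0]
```

The buggy variant mirrors the transcript before the fix: the first reset is correct, but the reset after a full episode returns the last day's prices.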