AI4Finance-Foundation / FinRL

FinRL: Financial Reinforcement Learning. 🔥
https://ai4finance.org
MIT License

Reset self.day before initializing state; otherwise self.state is incorrectly initialized when reset is called. #1042

Closed: AkashKarnatak closed this 1 year ago

AkashKarnatak commented 1 year ago

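Currently, StockTradingEnv.reset() sets self.day to 0 only after initializing the state. Once an episode has run to completion, self.day still points to the last trading day when the state is built, so the state returned by reset() contains the stock prices of the last day instead of the first day. This commit resets self.day before initializing the state to fix this issue.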
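To illustrate the ordering problem, here is a minimal sketch using a hypothetical toy environment. `ToyTradingEnv`, `_initiate_state`, `fix_order`, and the three-day price table are illustrative assumptions, not FinRL's actual code. The only point it shows is that the state is read from `self.day`, so rewinding the day after building the state leaks the previous episode's last-day prices into the new one.

```python
class ToyTradingEnv:
    """Hypothetical stand-in for StockTradingEnv: state = [cash] + prices on self.day."""

    def __init__(self, prices, initial_amount=1000.0, fix_order=True):
        self.prices = prices              # one list of closing prices per trading day
        self.initial_amount = initial_amount
        self.fix_order = fix_order        # True: fixed reset(); False: original ordering
        self.day = 0

    def _initiate_state(self):
        # The observation reads whichever day self.day currently points to.
        return [self.initial_amount] + list(self.prices[self.day])

    def reset(self):
        if self.fix_order:
            self.day = 0                          # rewind to the first day...
            self.state = self._initiate_state()   # ...then build the state from it
        else:
            self.state = self._initiate_state()   # built from the stale (last) day
            self.day = 0                          # rewound too late
        return self.state


prices = [[10.0, 20.0], [11.0, 21.0], [12.0, 22.0]]

for fix_order in (False, True):
    env = ToyTradingEnv(prices, fix_order=fix_order)
    env.day = len(prices) - 1  # simulate an episode that ran to the last day
    print(fix_order, env.reset()[1:])
# False [12.0, 22.0]  -> previous episode's last-day prices
# True [10.0, 20.0]   -> first-day prices, as expected
```

Setting `env.day` by hand stands in for stepping through a full episode; in the real environment the day is advanced by `step()`.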

zhumingpassional commented 1 year ago

Have you tested it? Does it work?

AkashKarnatak commented 1 year ago

Yes, here are the results:

Before commit

>>> env = StockTradingEnv(df = train, **env_kwargs)
>>> print('Day 0:', train.loc[0].close.values)
Day 0: [ 19.94001961  61.10700989 222.36936951  79.38145447  97.45033264
 144.63755798  15.24585056   5.74729586 303.35714722  33.54574585
 105.43174744 178.36380005 292.96624756 415.73022461  19.91226387
 166.92402649  22.62735176 251.48953247  91.10407257 493.23080444
  49.65197372 189.57377625  71.15727234  37.47664261 102.81951904
  41.22008896  18.54531288  96.87306976 300.88070679  64.68822479
 489.25418091 443.50683594  89.87454987  62.28596878  40.28764725
 287.0161438  111.89730072  94.98454285  50.7267189   32.15766907
   7.10770512  94.18431854  54.51696396  43.74747086 364.16394043
  65.57043457  47.51332855]

>>> last_day = len(train.date.unique())
>>> print(f'Day {last_day}:', train.loc[last_day-1].close.values)
Day 2577: [  148.79171753   399.91134644  1342.78930664  1320.36901855
   806.31970215  2465.26635742   851.19567871  3647.03857422
   343.90393066   299.88269043  2462.59765625   543.18145752
  1550.21484375  2458.41845703  1718.94885254   894.87109375
   472.84255981  2033.78674316  1186.35107422  2243.98901367
   201.33979797  1682.8092041    431.11804199  1370.0057373
   652.5690918    231.3109436    241.91317749  1474.68359375
  1435.78051758   629.96569824  6305.32128906 11118.05566406
   110.75399017   129.90248108   123.15526581  1222.25524902
   345.72662354   383.10574341   261.18582153   162.55000305
    20.83191109  2044.33349609   620.47149658  1304.79138184
  4482.19433594   601.64862061   275.31658936]

>>> s, _ = env.reset()
>>> print('First reset', s[1: 1 + stock_dimension])
First reset [ 19.94001961  61.10700989 222.36936951  79.38145447  97.45033264
 144.63755798  15.24585056   5.74729586 303.35714722  33.54574585
 105.43174744 178.36380005 292.96624756 415.73022461  19.91226387
 166.92402649  22.62735176 251.48953247  91.10407257 493.23080444
  49.65197372 189.57377625  71.15727234  37.47664261 102.81951904
  41.22008896  18.54531288  96.87306976 300.88070679  64.68822479
 489.25418091 443.50683594  89.87454987  62.28596878  40.28764725
 287.0161438  111.89730072  94.98454285  50.7267189   32.15766907
   7.10770512  94.18431854  54.51696396  43.74747086 364.16394043
  65.57043457  47.51332855]

>>> while True:
...    s_, r, d, t, _ = env.step(env.action_space.sample())
...    s = s_
...    if d or t: break
...
>>> s, _ = env.reset()
>>> print('Second reset', s[1: 1 + stock_dimension])
Second reset [  148.79171753   399.91134644  1342.78930664  1320.36901855
   806.31970215  2465.26635742   851.19567871  3647.03857422
   343.90393066   299.88269043  2462.59765625   543.18145752
  1550.21484375  2458.41845703  1718.94885254   894.87109375
   472.84255981  2033.78674316  1186.35107422  2243.98901367
   201.33979797  1682.8092041    431.11804199  1370.0057373
   652.5690918    231.3109436    241.91317749  1474.68359375
  1435.78051758   629.96569824  6305.32128906 11118.05566406
   110.75399017   129.90248108   123.15526581  1222.25524902
   345.72662354   383.10574341   261.18582153   162.55000305
    20.83191109  2044.33349609   620.47149658  1304.79138184
  4482.19433594   601.64862061   275.31658936]

After commit

>>> env = StockTradingEnv(df = train, **env_kwargs)
>>> print('Day 0:', train.loc[0].close.values)
Day 0: [ 19.94001961  61.10700989 222.36936951  79.38145447  97.45033264
 144.63755798  15.24585056   5.74729586 303.35714722  33.54574585
 105.43174744 178.36380005 292.96624756 415.73022461  19.91226387
 166.92402649  22.62735176 251.48953247  91.10407257 493.23080444
  49.65197372 189.57377625  71.15727234  37.47664261 102.81951904
  41.22008896  18.54531288  96.87306976 300.88070679  64.68822479
 489.25418091 443.50683594  89.87454987  62.28596878  40.28764725
 287.0161438  111.89730072  94.98454285  50.7267189   32.15766907
   7.10770512  94.18431854  54.51696396  43.74747086 364.16394043
  65.57043457  47.51332855]

>>> last_day = len(train.date.unique())
>>> print(f'Day {last_day}:', train.loc[last_day-1].close.values)
Day 2577: [  148.79171753   399.91134644  1342.78930664  1320.36901855
   806.31970215  2465.26635742   851.19567871  3647.03857422
   343.90393066   299.88269043  2462.59765625   543.18145752
  1550.21484375  2458.41845703  1718.94885254   894.87109375
   472.84255981  2033.78674316  1186.35107422  2243.98901367
   201.33979797  1682.8092041    431.11804199  1370.0057373
   652.5690918    231.3109436    241.91317749  1474.68359375
  1435.78051758   629.96569824  6305.32128906 11118.05566406
   110.75399017   129.90248108   123.15526581  1222.25524902
   345.72662354   383.10574341   261.18582153   162.55000305
    20.83191109  2044.33349609   620.47149658  1304.79138184
  4482.19433594   601.64862061   275.31658936]

>>> s, _ = env.reset()
>>> print('First reset', s[1: 1 + stock_dimension])
First reset [ 19.94001961  61.10700989 222.36936951  79.38145447  97.45033264
 144.63755798  15.24585056   5.74729586 303.35714722  33.54574585
 105.43174744 178.36380005 292.96624756 415.73022461  19.91226387
 166.92402649  22.62735176 251.48953247  91.10407257 493.23080444
  49.65197372 189.57377625  71.15727234  37.47664261 102.81951904
  41.22008896  18.54531288  96.87306976 300.88070679  64.68822479
 489.25418091 443.50683594  89.87454987  62.28596878  40.28764725
 287.0161438  111.89730072  94.98454285  50.7267189   32.15766907
   7.10770512  94.18431854  54.51696396  43.74747086 364.16394043
  65.57043457  47.51332855]

>>> while True:
...    s_, r, d, t, _ = env.step(env.action_space.sample())
...    s = s_
...    if d or t: break
...
>>> s, _ = env.reset()
>>> print('Second reset', s[1: 1 + stock_dimension])
Second reset [ 19.94001961  61.10700989 222.36936951  79.38145447  97.45033264
 144.63755798  15.24585056   5.74729586 303.35714722  33.54574585
 105.43174744 178.36380005 292.96624756 415.73022461  19.91226387
 166.92402649  22.62735176 251.48953247  91.10407257 493.23080444
  49.65197372 189.57377625  71.15727234  37.47664261 102.81951904
  41.22008896  18.54531288  96.87306976 300.88070679  64.68822479
 489.25418091 443.50683594  89.87454987  62.28596878  40.28764725
 287.0161438  111.89730072  94.98454285  50.7267189   32.15766907
   7.10770512  94.18431854  54.51696396  43.74747086 364.16394043
  65.57043457  47.51332855]
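The manual before/after comparison above could be automated as a regression check. This is a minimal sketch, assuming the same Gymnasium-style API the transcript uses (`reset()` returns `(state, info)`, `step()` returns a 5-tuple) and that prices sit in `state[1 : 1 + stock_dimension]`. `check_reset_restores_first_day`, `FakeEnv`, and `_FakeSpace` are hypothetical names; the fake environment exists only so the check can be run without FinRL or market data.

```python
import random


def check_reset_restores_first_day(env, first_day_prices, stock_dimension):
    """Return True if reset() yields day-0 prices both before and after a full episode.

    Assumes reset() -> (state, info) and
    step(action) -> (state, reward, terminated, truncated, info).
    """
    expected = list(first_day_prices)
    s, _ = env.reset()
    if list(s[1:1 + stock_dimension]) != expected:
        return False
    # Run one episode with random actions until it terminates or is truncated.
    while True:
        s, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            break
    s, _ = env.reset()
    return list(s[1:1 + stock_dimension]) == expected


class _FakeSpace:
    """Hypothetical action space; the fake env ignores the sampled action."""

    def sample(self):
        return random.choice([-1, 0, 1])


class FakeEnv:
    """Hypothetical env whose reset() ordering bug is toggled by `buggy`."""

    def __init__(self, prices, buggy):
        self.prices, self.buggy, self.day = prices, buggy, 0
        self.action_space = _FakeSpace()

    def _state(self):
        return [1000.0] + self.prices[self.day]

    def reset(self):
        if self.buggy:
            state = self._state()  # state built before rewinding the day
            self.day = 0
        else:
            self.day = 0           # day rewound before building the state
            state = self._state()
        return state, {}

    def step(self, action):
        self.day += 1
        terminated = self.day == len(self.prices) - 1
        return self._state(), 0.0, terminated, False, {}


prices = [[10.0, 20.0], [12.0, 22.0]]
print(check_reset_restores_first_day(FakeEnv(prices, buggy=True), prices[0], 2))   # False
print(check_reset_restores_first_day(FakeEnv(prices, buggy=False), prices[0], 2))  # True
```

Against a real StockTradingEnv, the same check would be called with `train.loc[0].close.values` as `first_day_prices`; exact equality works there because both sides come from the same DataFrame values, though a tolerance-based comparison may be safer if the state is ever rescaled.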