AI4Finance-Foundation / FinRL

FinRL: Financial Reinforcement Learning. 🔥
https://ai4finance.org
MIT License
9.65k stars 2.34k forks source link

DRLEnsembleAgent `__init__` is missing the `num_stock_shares` argument #556

Closed khanhphan1311 closed 2 years ago

khanhphan1311 commented 2 years ago

DRLEnsembleAgent calls env_stocktrading.StockTradingEnv with the following arguments (without num_stock_shares):

# Old call, written before num_stock_shares was added. Against the updated
# StockTradingEnv.__init__ signature these ten positional arguments bind to
# df .. action_space: buy_cost_pct lands in the num_stock_shares slot, every
# later positional argument shifts one position left, and the required
# tech_indicator_list parameter is left unbound, so Python raises TypeError.
StockTradingEnv(
                    trade_data,
                    self.stock_dim,
                    self.hmax,
                    self.initial_amount,
                    self.buy_cost_pct,  # binds to num_stock_shares under the new signature
                    self.sell_cost_pct,
                    self.reward_scaling,
                    self.state_space,
                    self.action_space,
                    self.tech_indicator_list,
                    turbulence_threshold=turbulence_threshold,
                    initial=initial,
                    previous_state=last_state,
                    model_name=name,
                    mode="trade",
                    iteration=iter_num,
                    print_verbosity=self.print_verbosity,
                )

but StockTradingEnv's `__init__` has since been updated to take num_stock_shares:

# Excerpt: only the constructor signature is quoted here; its body is omitted.
class StockTradingEnv(gym.Env):
    """A stock trading environment for OpenAI gym"""

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        df: pd.DataFrame,
        stock_dim: int,
        hmax: int,
        initial_amount: int,
        # Required (no default) and positional, right after initial_amount.
        # Callers written for the older signature that omit it get every
        # later positional argument shifted by one slot.
        num_stock_shares: List[int],
        buy_cost_pct: List[float],
        sell_cost_pct: List[float],
        reward_scaling: float,
        state_space: int,
        action_space: int,
        tech_indicator_list: List[str],
        turbulence_threshold=None,
        risk_indicator_col="turbulence",
        make_plots: bool =False,
        print_verbosity=10,
        day=0,
        initial=True,
        # NOTE(review): mutable default list, created once and shared by every
        # call that relies on the default; harmless only if the (omitted) body
        # never mutates it. Consider `None` plus an in-body fallback.
        previous_state=[],
        model_name="",
        mode="",
        iteration="",
    ):

I suggest updating DRLEnsembleAgent's `__init__` to accept num_stock_shares, as follows:

def __init__(
    self,
    df,
    train_period,
    val_test_period,
    rebalance_window,
    validation_window,
    stock_dim,
    hmax,
    initial_amount,
    num_stock_shares,
    buy_cost_pct,
    sell_cost_pct,
    reward_scaling,
    state_space,
    action_space,
    tech_indicator_list,
    print_verbosity,
):
    """Record the ensemble agent's data, windows and environment settings.

    Every argument is stored on the instance under its own name. In addition,
    ``unique_trade_date`` holds the distinct ``df.date`` values falling in the
    half-open window ``(val_test_period[0], val_test_period[1]]``.
    ``num_stock_shares`` is kept alongside the other environment settings so
    it can be handed to ``StockTradingEnv``, whose constructor now requires it
    immediately after ``initial_amount``.
    """
    # Source data and the periods it is split into.
    self.df = df
    self.train_period = train_period
    self.val_test_period = val_test_period

    # Trade dates: strictly after the window start, up to and including its end.
    window_start = val_test_period[0]
    window_end = val_test_period[1]
    in_window = (df.date > window_start) & (df.date <= window_end)
    self.unique_trade_date = df[in_window].date.unique()

    self.rebalance_window = rebalance_window
    self.validation_window = validation_window

    # Settings forwarded to the trading environment.
    self.stock_dim = stock_dim
    self.hmax = hmax
    self.initial_amount = initial_amount
    self.num_stock_shares = num_stock_shares
    self.buy_cost_pct = buy_cost_pct
    self.sell_cost_pct = sell_cost_pct
    self.reward_scaling = reward_scaling
    self.state_space = state_space
    self.action_space = action_space
    self.tech_indicator_list = tech_indicator_list
    self.print_verbosity = print_verbosity

StockTradingEnv can then be called as follows:

# Pass every argument by keyword so this call no longer depends on the
# positional order of StockTradingEnv.__init__. That dependency is what broke
# when num_stock_shares was inserted after initial_amount. With keywords, a
# newly added required parameter produces a clear "missing argument" TypeError
# instead of silently shifting values into the wrong parameters.
StockTradingEnv(
                    df=trade_data,
                    stock_dim=self.stock_dim,
                    hmax=self.hmax,
                    initial_amount=self.initial_amount,
                    num_stock_shares=self.num_stock_shares,
                    buy_cost_pct=self.buy_cost_pct,
                    sell_cost_pct=self.sell_cost_pct,
                    reward_scaling=self.reward_scaling,
                    state_space=self.state_space,
                    action_space=self.action_space,
                    tech_indicator_list=self.tech_indicator_list,
                    turbulence_threshold=turbulence_threshold,
                    initial=initial,
                    previous_state=last_state,
                    model_name=name,
                    mode="trade",
                    iteration=iter_num,
                    print_verbosity=self.print_verbosity,
                )
zhumingpassional commented 2 years ago

We will update it in the next few days.