AI4Finance-Foundation / FinRL

FinRL: Financial Reinforcement Learning. 🔥
https://ai4finance.org
MIT License

Improvements I want to make in [finrl]->[agent]->[rllib]->[models.py] #1193

Open · Aditya-dom opened this issue 3 months ago

Aditya-dom commented 3 months ago

Here are the improvements made to the code:

1. Imported with_common_config, Trainer, and COMMON_CONFIG to make the code cleaner and more concise.
2. Used the individual algorithm trainers exposed under rllib.agents instead of importing them directly from their respective modules, to maintain consistency and readability.
3. Added a private method _get_default_config that retrieves the default configuration for each model, reducing code duplication.
4. Improved error handling in the DRL_prediction method by catching exceptions and raising a ValueError with a meaningful error message.

# DRL models from RLlib
from __future__ import annotations

import ray
from ray.rllib.agents import with_common_config

from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG

# Import individual algorithms for easier access
from ray.rllib.agents.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG
from ray.rllib.agents.ddpg import DDPGTrainer, DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
# TD3 ships its own default config, distinct from DDPG's DEFAULT_CONFIG
from ray.rllib.agents.ddpg import TD3Trainer, TD3_DEFAULT_CONFIG as TD3_CONFIG

MODELS = {"a3c": A3CTrainer, "ddpg": DDPGTrainer, "td3": TD3Trainer, "sac": SACTrainer, "ppo": PPOTrainer}

class DRLAgent:
    """Implementations for DRL algorithms

    Attributes
    ----------
        env: gym environment class
            user-defined class
        price_array: numpy array
            OHLC data
        tech_array: numpy array
            technical data
        turbulence_array: numpy array
            turbulence/risk data
    Methods
    -------
        get_model()
            setup DRL algorithms
        train_model()
            train DRL algorithms in a train dataset
            and output the trained model
        DRL_prediction()
            make a prediction in a test dataset and get results
    """

    def __init__(self, env, price_array, tech_array, turbulence_array):
        self.env = env
        self.price_array = price_array
        self.tech_array = tech_array
        self.turbulence_array = turbulence_array

    def get_model(
        self,
        model_name,
        # policy="MlpPolicy",
        # policy_kwargs=None,
        # model_kwargs=None,
    ):
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported. Choose one of: {list(MODELS)}"
            )

        model = MODELS[model_name]
        model_config = self._get_default_config(model_name)

        # pass env, log_level, price_array, tech_array, and turbulence_array to config
        model_config["env"] = self.env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": self.price_array,
            "tech_array": self.tech_array,
            "turbulence_array": self.turbulence_array,
            "if_train": True,
        }

        return model, model_config

    def train_model(
        self, model, model_name, model_config, total_episodes=100, init_ray=True
    ):
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported. Choose one of: {list(MODELS)}"
            )
        if init_ray:
            ray.init(
                ignore_reinit_error=True
            )

        trainer = model(env=self.env, config=model_config)

        for _ in range(total_episodes):
            trainer.train()

        ray.shutdown()

        cwd = "./test_" + str(model_name)
        trainer.save(cwd)

        return trainer

    @staticmethod
    def DRL_prediction(
        model_name,
        env,
        price_array,
        tech_array,
        turbulence_array,
        agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
    ):
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported. Choose one of: {list(MODELS)}"
            )

        model = MODELS[model_name]
        model_config = DRLAgent._get_default_config(model_name)  # no self inside a staticmethod

        model_config["env"] = env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": False,
        }
        env_config = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": False,
        }
        env_instance = env(config=env_config)

        trainer = model(env=env, config=model_config)

        try:
            trainer.restore(agent_path)
            print("Restoring from checkpoint path", agent_path)
        except BaseException as e:
            raise ValueError("Fail to load agent!") from e

        state = env_instance.reset()
        episode_returns = []  # total_asset / initial_total_asset at each step
        episode_total_assets = [env_instance.initial_total_asset]
        episode_return = 0.0  # guard in case the episode terminates immediately
        done = False
        while not done:
            action = trainer.compute_single_action(state)
            state, reward, done, _ = env_instance.step(action)

            total_asset = (
                env_instance.amount
                + (env_instance.price_ary[env_instance.day] * env_instance.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            episode_return = total_asset / env_instance.initial_total_asset
            episode_returns.append(episode_return)

        ray.shutdown()
        print("episode return: " + str(episode_return))
        print("Test Finished!")
        return episode_total_assets

    @staticmethod
    def _get_default_config(model_name):
        # Return a copy of the RLlib default config for the requested model.
        if model_name == "a3c":
            return A3C_CONFIG.copy()
        elif model_name == "ddpg":
            return DDPG_CONFIG.copy()
        elif model_name == "td3":
            return TD3_CONFIG.copy()
        elif model_name == "sac":
            return SAC_CONFIG.copy()
        elif model_name == "ppo":
            return PPO_CONFIG.copy()
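
For reference, a minimal usage sketch of the refactored class could look like this. StockTradingEnv, the .npy file names, and the checkpoint path are placeholders for illustration, not part of the actual change:

# Hypothetical usage sketch (not part of the patch): the env class, data files,
# and checkpoint path below are placeholders.
import numpy as np

from my_project.envs import StockTradingEnv  # placeholder: any env class whose
                                             # config dict uses the keys above

price_array = np.load("price_array.npy")            # OHLC data
tech_array = np.load("tech_array.npy")              # technical indicators
turbulence_array = np.load("turbulence_array.npy")  # turbulence/risk data

agent = DRLAgent(StockTradingEnv, price_array, tech_array, turbulence_array)

# Build the PPO trainer class plus its default config with env_config filled in.
model, model_config = agent.get_model("ppo")

# Run 100 RLlib training iterations and save a checkpoint under ./test_ppo.
trainer = agent.train_model(model, "ppo", model_config, total_episodes=100)

# Restore the checkpoint and run one backtest episode.
episode_total_assets = DRLAgent.DRL_prediction(
    model_name="ppo",
    env=StockTradingEnv,
    price_array=price_array,
    tech_array=tech_array,
    turbulence_array=turbulence_array,
    agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
)
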
cpzz50 commented 1 month ago

Hi, wondering which version of ray is being used. I notice that ray moved rllib.agents.[algorithm] to rllib.algorithms.[algorithm] a long time ago.

Following is what I changed to make it run: pin ray version ==2.1.0, change 'from ray.rllib.agents.sac' to 'from ray.rllib.algorithms.sac', and delete the a3c and td3 algorithms, because they are not available in the new version.
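
A rough sketch of what those import changes look like on ray 2.1.0, assuming the renamed rllib.algorithms API (PPOTrainer -> PPO, SACTrainer -> SAC, etc.) and keeping only ddpg/ppo/sac as described above:

# Sketch for ray >= 2.x: rllib.agents became rllib.algorithms and the *Trainer
# classes were renamed to plain algorithm names.
from ray.rllib.algorithms.ddpg import DDPG, DDPGConfig
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.sac import SAC, SACConfig

# Only the algorithms kept in this setup (a3c and td3 dropped, as noted above).
MODELS = {"ddpg": DDPG, "ppo": PPO, "sac": SAC}

# Module-level DEFAULT_CONFIG dicts are replaced by *Config classes; a plain
# dict can still be produced for code that mutates config keys directly.
MODEL_CONFIGS = {
    "ddpg": DDPGConfig().to_dict(),
    "ppo": PPOConfig().to_dict(),
    "sac": SACConfig().to_dict(),
}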

Aditya-dom commented 1 month ago

Hi, wondering which version of ray is being used. I notice that ray moved rllib.agents.[algorithm] to rllib.algorithms.[algorithm] a long time ago.

Following is what I changed to make it run: pin ray version ==2.1.0, change 'from ray.rllib.agents.sac' to 'from ray.rllib.algorithms.sac', and delete the a3c and td3 algorithms, because they are not available in the new version.

Thanks for the review buddy!!