AI4Finance-Foundation / FinRL

FinRL: Financial Reinforcement Learning. 🔥
https://ai4finance.org
MIT License

Fix portfolio optimization on GPU #1158

Closed: C4i0kun closed this pull request 5 months ago

C4i0kun commented 5 months ago

Greetings,

As pointed out here, training a convolutional portfolio optimization architecture on a GPU failed. The cause was that the user had to set device="cuda" in both the model kwargs and the policy kwargs. So I changed the interface so that the device is passed only once, to get_model:

model = DRLAgent(environment).get_model("pg", "cuda", model_kwargs, policy_kwargs)
DRLAgent.train_model(model, episodes=100)
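The change amounts to accepting the device once and forwarding it into both keyword-argument dictionaries, so the model and the policy can never end up on different devices. The snippet below is a minimal plain-Python sketch of that pattern, not FinRL's actual implementation: `merge_device_kwargs` is a hypothetical helper, and the keys `lr` and `time_window` are illustrative placeholders.

```python
from typing import Any, Dict, Optional, Tuple


def merge_device_kwargs(
    device: str = "cpu",
    model_kwargs: Optional[Dict[str, Any]] = None,
    policy_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: copy one device string into both kwargs dicts."""
    # Copy the inputs so the caller's dictionaries are never mutated.
    model_kwargs = dict(model_kwargs or {})
    policy_kwargs = dict(policy_kwargs or {})
    # The single device argument overrides any per-dict value,
    # so model and policy always agree on where tensors live.
    model_kwargs["device"] = device
    policy_kwargs["device"] = device
    return model_kwargs, policy_kwargs


# Illustrative usage with placeholder settings.
user_model_kwargs = {"lr": 0.01}
user_policy_kwargs = {"time_window": 50}
m, p = merge_device_kwargs("cuda", user_model_kwargs, user_policy_kwargs)
```

Letting the explicit argument win over any `device` key already present is one design choice; an alternative would be to raise an error when the two disagree.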

I think this pull request solves the issue.

zhumingpassional commented 5 months ago

Good code.