tensortrade-org / tensortrade

An open source reinforcement learning framework for training, evaluating, and deploying robust trading agents.
https://discord.gg/ZZ7BGWh
Apache License 2.0

Not possible to use GPU with Tensortrade environments #382

Open avacaondata opened 2 years ago

avacaondata commented 2 years ago

System information

Describe the current behavior

When trying to run this example: https://www.tensortrade.org/en/latest/examples/train_and_evaluate_using_ray.html (the official tutorial on training and evaluating with Ray), I cannot do it with the GPU enabled. However, if I run the same Ray setup but with CartPole as the environment, with the same library environment and everything else unchanged, it can use the GPU. Therefore this has something to do with TensorTrade specifically, not with Ray or PyTorch. The error is the following:

 pid=2444) ray::PPOTrainer.__init__() (pid=2444, ip=127.0.0.1, repr=PPOTrainer)
 pid=2444)   File "python\ray\_raylet.pyx", line 633, in ray._raylet.execute_task
 pid=2444)   File "python\ray\_raylet.pyx", line 674, in ray._raylet.execute_task
 pid=2444)   File "python\ray\_raylet.pyx", line 640, in ray._raylet.execute_task
 pid=2444)   File "python\ray\_raylet.pyx", line 644, in ray._raylet.execute_task
 pid=2444)   File "python\ray\_raylet.pyx", line 593, in ray._raylet.execute_task.function_executor
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\_private\function_manager.py", line 648, in actor_method_executor
 pid=2444)     return method(__ray_actor, *args, **kwargs)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
 pid=2444)     return method(self, *_args, **_kwargs)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\trainer.py", line 741, in __init__
 pid=2444)     super().__init__(config, logger_creator, remote_checkpoint_dir,
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\tune\trainable.py", line 124, in __init__
 pid=2444)     self.setup(copy.deepcopy(self.config))
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
 pid=2444)     return method(self, *_args, **_kwargs)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\trainer.py", line 846, in setup
 pid=2444)     self.workers = self._make_workers(
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
 pid=2444)     return method(self, *_args, **_kwargs)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\trainer.py", line 1971, in _make_workers
 pid=2444)     return WorkerSet(
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 123, in __init__
 pid=2444)     self._local_worker = self._make_worker(
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 499, in _make_worker
 pid=2444)     worker = cls(
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 586, in __init__
 pid=2444)     self._build_policy_map(
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1569, in _build_policy_map
 pid=2444)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\policy\policy_map.py", line 143, in create_policy
 pid=2444)     self[policy_id] = class_(observation_space, action_space,
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\ppo\ppo_torch_policy.py", line 50, in __init__
 pid=2444)     self._initialize_loss_from_dummy_batch()
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\policy\policy.py", line 832, in _initialize_loss_from_dummy_batch
 pid=2444)     self.compute_actions_from_input_dict(
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\policy\torch_policy.py", line 294, in compute_actions_from_input_dict
 pid=2444)     return self._compute_action_helper(input_dict, state_batches,
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\utils\threading.py", line 21, in wrapper
 pid=2444)     return func(self, *a, **k)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\policy\torch_policy.py", line 934, in _compute_action_helper
 pid=2444)     dist_inputs, state_out = self.model(input_dict, state_batches,
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\modelv2.py", line 243, in __call__
 pid=2444)     res = self.forward(restored, state or [], seq_lens)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\torch\complex_input_net.py", line 193, in forward
 pid=2444)     nn_out, _ = self.flatten[i](SampleBatch({
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\modelv2.py", line 243, in __call__
 pid=2444)     res = self.forward(restored, state or [], seq_lens)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\torch\fcnet.py", line 124, in forward
 pid=2444)     self._features = self._hidden_layers(self._last_flat_in)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
 pid=2444)     return forward_call(*input, **kwargs)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
 pid=2444)     input = module(input)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
 pid=2444)     return forward_call(*input, **kwargs)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\torch\misc.py", line 160, in forward
 pid=2444)     return self._model(x)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
 pid=2444)     return forward_call(*input, **kwargs)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
 pid=2444)     input = module(input)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
 pid=2444)     return forward_call(*input, **kwargs)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\linear.py", line 103, in forward
 pid=2444)     return F.linear(input, self.weight, self.bias)
 pid=2444)   File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\functional.py", line 1848, in linear
 pid=2444)     return torch._C._nn.linear(input, weight, bias)
 pid=2444) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)
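
For reference, a CartPole run along these lines (a minimal sketch, not the exact script I used, but with the same framework and num_gpus settings) trains on the GPU without this error:

import ray
from ray import tune

ray.init(num_gpus=1)

tune.run(
    run_or_experiment="PPO",
    stop={"training_iteration": 1},
    config={
        "env": "CartPole-v1",  # built-in gym env instead of the TensorTrade env
        "framework": "torch",
        "num_gpus": 1,  # works here, but fails with the TensorTrade env
        "num_workers": 1,
    },
)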

Describe the expected behavior

The expected behavior is that TensorTrade envs return observations in the same format as other envs such as CartPole, so that tensors can be sent to the GPU.
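
A quick way to see the difference (a sketch, assuming the create_env function from the repro below and that training.csv already exists) is to compare the observation spaces directly; the TensorTrade observation is the one that gets routed through RLlib's ComplexInputNetwork in the traceback above:

import gym

cartpole = gym.make("CartPole-v1")
print(cartpole.observation_space)  # flat float32 Box, handled by RLlib's plain fully connected net

tt_env = create_env({
    "window_size": 14,
    "reward_window_size": 7,
    "max_allowed_loss": 0.10,
    "csv_filename": "training.csv",
})
print(tt_env.observation_space)  # windowed Box, routed through ComplexInputNetwork (see traceback)
print(tt_env.observation_space.dtype)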

Code to reproduce the issue

import yfinance
import pandas_ta  #noqa
import pandas as pd
from tensortrade.feed.core import DataFeed, Stream
from tensortrade.oms.instruments import Instrument
from tensortrade.oms.exchanges import Exchange, ExchangeOptions
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.wallets import Wallet, Portfolio
import tensortrade.env.default as default
import ray
import os
from ray import tune
from ray.tune.registry import register_env

TICKER = 'AAPL'  # TODO: replace this with your own ticker
TRAIN_START_DATE = '2021-02-09'  # TODO: replace this with your own start date
TRAIN_END_DATE = '2021-09-30'  # TODO: replace this with your own end date
EVAL_START_DATE = '2021-10-01'  # TODO: replace this with your own start date
EVAL_END_DATE = '2021-11-12'  # TODO: replace this with your own end date

yf_ticker = yfinance.Ticker(ticker=TICKER)

df_training = yf_ticker.history(start=TRAIN_START_DATE, end=TRAIN_END_DATE, interval='60m')
df_training.drop(['Dividends', 'Stock Splits'], axis=1, inplace=True)
df_training["Volume"] = df_training["Volume"].astype(int)
df_training.ta.log_return(append=True, length=16)
df_training.ta.rsi(append=True, length=14)
df_training.ta.macd(append=True, fast=12, slow=26)
df_training.to_csv('training.csv', index=True)

df_evaluation = yf_ticker.history(start=EVAL_START_DATE, end=EVAL_END_DATE, interval='60m')
df_evaluation.drop(['Dividends', 'Stock Splits'], axis=1, inplace=True)
df_evaluation["Volume"] = df_evaluation["Volume"].astype(int)
df_evaluation.ta.log_return(append=True, length=16)
df_evaluation.ta.rsi(append=True, length=14)
df_evaluation.ta.macd(append=True, fast=12, slow=26)
df_evaluation.to_csv('evaluation.csv', index=True)

def create_env(config):
    dataset = pd.read_csv(filepath_or_buffer=config["csv_filename"], parse_dates=['Datetime']).fillna(method='backfill').fillna(method='ffill')
    ttse_commission = 0.0035  # TODO: adjust according to your commission percentage, if present
    price = Stream.source(list(dataset["Close"]), dtype="float").rename("USD-TTRD")
    ttse_options = ExchangeOptions(commission=ttse_commission)
    ttse_exchange = Exchange("TTSE", service=execute_order, options=ttse_options)(price)

    # Instruments, Wallets and Portfolio
    USD = Instrument("USD", 2, "US Dollar")
    TTRD = Instrument("TTRD", 2, "TensorTrade Corp")
    cash = Wallet(ttse_exchange, 1000 * USD)  # This is the starting cash we are going to use
    asset = Wallet(ttse_exchange, 0 * TTRD)  # And we will start owning 0 stocks of TTRD
    portfolio = Portfolio(USD, [cash, asset])

    # Renderer feed
    renderer_feed = DataFeed([
        Stream.source(list(dataset["Datetime"])).rename("date"),
        Stream.source(list(dataset["Open"]), dtype="float").rename("open"),
        Stream.source(list(dataset["High"]), dtype="float").rename("high"),
        Stream.source(list(dataset["Low"]), dtype="float").rename("low"),
        Stream.source(list(dataset["Close"]), dtype="float").rename("close"),
        Stream.source(list(dataset["Volume"]), dtype="float").rename("volume")
    ])

    features = []
    for c in dataset.columns[1:]:
        s = Stream.source(list(dataset[c]), dtype="float").rename(dataset[c].name)
        features += [s]
    feed = DataFeed(features)
    feed.compile()

    reward_scheme = default.rewards.SimpleProfit(window_size=config["reward_window_size"])
    action_scheme = default.actions.BSH(cash=cash, asset=asset)

    env = default.create(
            feed=feed,
            portfolio=portfolio,
            action_scheme=action_scheme,
            reward_scheme=reward_scheme,
            renderer_feed=renderer_feed,
            renderer=[],
            window_size=config["window_size"],
            max_allowed_loss=config["max_allowed_loss"]
        )

    return env

# Let's define some tuning parameters
FC_SIZE = tune.grid_search([[256, 256], [1024], [128, 64, 32]])  # Those are the alternatives that ray.tune will try...
LEARNING_RATE = tune.grid_search([0.001, 0.0005, 0.00001])  # ... and they will be combined with these ones ...
MINIBATCH_SIZE = tune.grid_search([5, 10, 20])  # ... and these ones, in a cartesian product.

# Get the current working directory
cwd = os.getcwd()

# Initialize Ray
ray.init(num_gpus=1)  # There are *LOTS* of initialization parameters, like specifying the maximum number of CPUs/GPUs to allocate. Here we explicitly ask Ray for one GPU.

# Register our environment, specifying which is the environment creation function
register_env("MyTrainingEnv", create_env)

# Specific configuration keys that will be used during training
env_config_training = {
    "window_size": 14,  # We want to look at the last 14 samples (hours)
    "reward_window_size": 7,  # And calculate reward based on the actions taken in the next 7 hours
    "max_allowed_loss": 0.10,  # If it goes past 10% loss during the iteration, we don't want to waste time on a "loser".
    "csv_filename": os.path.join(cwd, 'training.csv'),  # The variable that will be used to differentiate training and validation datasets
}
# Specific configuration keys that will be used during evaluation (only the overridden ones)
env_config_evaluation = {
    "max_allowed_loss": 1.00,  # During validation runs we want to see how bad it would go. Even up to 100% loss.
    "csv_filename": os.path.join(cwd, 'evaluation.csv'),  # The variable that will be used to differentiate training and validation datasets
}

analysis = tune.run(
    run_or_experiment="PPO",  # We'll be using the builtin PPO agent in RLLib
    name="MyExperiment1",
    metric='episode_reward_mean',
    mode='max',
    stop={
        "training_iteration": 5  # Let's do 5 steps for each hyperparameter combination
    },
    config={
        "env": "MyTrainingEnv",
        "env_config": env_config_training,  # The dictionary we built before
        "log_level": "WARNING",
        "framework": "torch",
        "ignore_worker_failures": True,
        "num_workers": 1,  # One worker per agent. You can increase this but it will run fewer parallel trainings.
        "num_envs_per_worker": 1,
        "num_gpus": 1,  # I yet have to understand if using a GPU is worth it, for our purposes, but I think it's not. This way you can train on a non-gpu enabled system.
        "clip_rewards": True,
        "lr": LEARNING_RATE,  # Hyperparameter grid search defined above
        "gamma": 0.50,  # This can have a big impact on the result and needs to be properly tuned (range is 0 to 1)
        "observation_filter": "MeanStdFilter",
        "model": {
            "fcnet_hiddens": FC_SIZE,  # Hyperparameter grid search defined above
        },
        "sgd_minibatch_size": MINIBATCH_SIZE,  # Hyperparameter grid search defined above
        "evaluation_interval": 1,  # Run evaluation on every iteration
        "evaluation_config": {
            "env_config": env_config_evaluation,  # The dictionary we built before (only the overriding keys to use in evaluation)
            "explore": False,  # We don't want to explore during evaluation. All actions have to be repeatable.
        },
    },
    num_samples=1,  # Have one sample for each hyperparameter combination. You can have more to average out randomness.
    keep_checkpoints_num=10,  # Keep the last 10 checkpoints
    checkpoint_freq=1,  # Do a checkpoint on each iteration (slower but you can pick more finely the checkpoint to use later)
)

Other info / logs

Failure # 1 (occurred at 2022-01-27_14-48-20)
Traceback (most recent call last):
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\tune\trial_runner.py", line 886, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\tune\ray_trial_executor.py", line 675, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\_private\client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\worker.py", line 1760, in get
    raise value
ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::PPOTrainer.__init__() (pid=13880, ip=127.0.0.1, repr=PPOTrainer)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\trainer.py", line 948, in _init
    raise NotImplementedError
NotImplementedError

During handling of the above exception, another exception occurred:

ray::PPOTrainer.__init__() (pid=13880, ip=127.0.0.1, repr=PPOTrainer)
  File "python\ray\_raylet.pyx", line 633, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 674, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 640, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 644, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 593, in ray._raylet.execute_task.function_executor
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\_private\function_manager.py", line 648, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\trainer.py", line 741, in __init__
    super().__init__(config, logger_creator, remote_checkpoint_dir,
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\tune\trainable.py", line 124, in __init__
    self.setup(copy.deepcopy(self.config))
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\trainer.py", line 846, in setup
    self.workers = self._make_workers(
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\util\tracing\tracing_helper.py", line 451, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\trainer.py", line 1971, in _make_workers
    return WorkerSet(
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 123, in __init__
    self._local_worker = self._make_worker(
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\evaluation\worker_set.py", line 499, in _make_worker
    worker = cls(
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 586, in __init__
    self._build_policy_map(
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1569, in _build_policy_map
    self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\policy\policy_map.py", line 143, in create_policy
    self[policy_id] = class_(observation_space, action_space,
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\agents\ppo\ppo_torch_policy.py", line 50, in __init__
    self._initialize_loss_from_dummy_batch()
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\policy\policy.py", line 832, in _initialize_loss_from_dummy_batch
    self.compute_actions_from_input_dict(
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\policy\torch_policy.py", line 294, in compute_actions_from_input_dict
    return self._compute_action_helper(input_dict, state_batches,
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\utils\threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\policy\torch_policy.py", line 934, in _compute_action_helper
    dist_inputs, state_out = self.model(input_dict, state_batches,
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\modelv2.py", line 243, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\torch\attention_net.py", line 349, in forward
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\torch\complex_input_net.py", line 193, in forward
    nn_out, _ = self.flatten[i](SampleBatch({
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\modelv2.py", line 243, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\torch\fcnet.py", line 124, in forward
    self._features = self._hidden_layers(self._last_flat_in)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
    input = module(input)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\ray\rllib\models\torch\misc.py", line 160, in forward
    return self._model(x)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
    input = module(input)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\modules\linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\Usuario\anaconda3\envs\cryptorl\lib\site-packages\torch\nn\functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)

@carlogrisetti @notadamking

carlogrisetti commented 2 years ago

It is sadly a known issue. The GPU improvements would be marginal though, since the GPU would only be used for computing the SGD updates (at least, that is my understanding so far).

If anyone wants to tackle this issue and try and figure out what's breaking, it would be nice for sure!
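
In the meantime, a minimal CPU-only fallback (a sketch, assuming the same register_env call and env_config_training from the repro above) is to simply not give the trainer a GPU:

from ray import tune

tune.run(
    run_or_experiment="PPO",
    stop={"training_iteration": 5},
    config={
        "env": "MyTrainingEnv",  # registered earlier with register_env
        "env_config": env_config_training,
        "framework": "torch",
        "num_gpus": 0,  # policy stays on CPU, so there is no cpu/cuda:0 mismatch
        "num_workers": 1,
    },
)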

bhavithran1 commented 2 years ago

I have the same problem as well. Hopefully this can be solved in the future...